diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java index 6976f884cf4c3686b2f3531b5f7c86514ee4342e..10853a996c8aee8f4c672122aeb90f4cee0f4ebd 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java @@ -31,6 +31,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.object.DynamicObject; @@ -38,6 +39,7 @@ import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.GetAttributeNode; import com.oracle.truffle.r.nodes.attributes.IterableAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.UpdateSharedAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; @@ -61,7 +63,6 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractVector; public abstract class Attr extends RBuiltinNode { private final ConditionProfile searchPartialProfile = ConditionProfile.createBinaryProfile(); - private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); private final BranchProfile errorProfile = BranchProfile.create(); @CompilationFinal private String cachedName = ""; @@ -145,13 +146,14 @@ public abstract class Attr extends RBuiltinNode { } @Specialization(guards = "isRowNamesAttr(name)") - protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact) { + protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact, + @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) { // TODO: if exact == false, check for partial match (there is an ignored tests for it) DynamicObject attributes = container.getAttributes(); if (attributes == null) { return RNull.instance; } else { - return getFullRowNames(container.getRowNames(attrProfiles)); + return getFullRowNames(getRowNamesNode.getRowNames(container)); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ShortRowNames.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ShortRowNames.java index 0676f6ba7a00f0dfb1b9e1480e269abf701ee2a9..0ea914972ddd148f9d8fbdc952edf6d85898de32 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ShortRowNames.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ShortRowNames.java @@ -28,11 +28,13 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.IntValueProfile; import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; @@ -48,7 +50,6 @@ import com.oracle.truffle.r.runtime.env.REnvironment; public abstract class ShortRowNames extends RBuiltinNode { private final BranchProfile naValueMet = BranchProfile.create(); - private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); private final BranchProfile errorProfile = BranchProfile.create(); private final ValueProfile operandTypeProfile = ValueProfile.createClassProfile(); @@ -62,11 +63,12 @@ public abstract class ShortRowNames extends RBuiltinNode { private final IntValueProfile typeProfile = IntValueProfile.createIdentityProfile(); @Specialization - protected Object getNames(Object originalOperand, int originalType) { + protected Object getNames(Object originalOperand, int originalType, + @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) { Object operand = operandTypeProfile.profile(originalOperand); Object rowNames; if (operand instanceof RAbstractContainer) { - rowNames = ((RAbstractContainer) operand).getRowNames(attrProfiles); + rowNames = getRowNamesNode.getRowNames((RAbstractContainer) operand); } else if (operand instanceof REnvironment) { if (getRowNamesAttrNode == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttr.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttr.java index abc2bdfa1636f5ea9a29fceca1e3c76634c47fb6..34c1d680f8290eead7dfaabaed28894cabb35bd4 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttr.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttr.java @@ -163,7 +163,11 @@ public abstract class UpdateAttr extends RBuiltinNode { setClassAttrNode.reset(result); return result; } else if (internedName == RRuntime.ROWNAMES_ATTR_KEY) { - result.setRowNames(null); + if (setRowNamesAttrNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + setRowNamesAttrNode = insert(SetRowNamesAttributeNode.create()); + } + setRowNamesAttrNode.setRowNames(result, null); } else if (result.getAttributes() != null) { result.removeAttr(attrProfiles, internedName); } @@ -213,7 +217,7 @@ public abstract class UpdateAttr extends RBuiltinNode { CompilerDirectives.transferToInterpreterAndInvalidate(); setRowNamesAttrNode = insert(SetRowNamesAttributeNode.create()); } - setRowNamesAttrNode.execute(result, castVector(value)); + setRowNamesAttrNode.setRowNames(result, castVector(value)); } else { // generic attribute if (setGenAttrNode == null) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java index 9bd6a9ed4aa25415dba968acfcde5aa5bcc6c39f..4e0c99102cbac929b1bff3705111ca5f2e9f4f43 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java @@ -34,6 +34,7 @@ import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SetAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetRowNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.unary.CastIntegerNode; @@ -66,6 +67,7 @@ public abstract class UpdateAttributes extends RBuiltinNode { @Child private CastToVectorNode castVector; @Child private SetAttributeNode setAttrNode; @Child private SetDimAttributeNode setDimNode; + @Child private SetRowNamesAttributeNode setRowNamesNode; @Override protected void createCasts(CastBuilder casts) { @@ -208,7 +210,11 @@ public abstract class UpdateAttributes extends RBuiltinNode { } res = result; } else if (attrName.equals(RRuntime.ROWNAMES_ATTR_KEY)) { - res.setRowNames(castVector(value)); + if (setRowNamesNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + setRowNamesNode = insert(SetRowNamesAttributeNode.create()); + } + setRowNamesNode.setRowNames(res, castVector(value)); } else { if (value == RNull.instance) { res.removeAttr(attrProfiles, attrName); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java index c55c482899d8291b684b18de0d960b5918055681..d064aeb7d7548c7975134f5cc5e1ce03aac6df99 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java @@ -184,9 +184,9 @@ public final class SpecialAttributesFunctions { } else if (name == RRuntime.DIM_ATTR_KEY) { return SetDimAttributeNode.create(); } else if (name == RRuntime.DIMNAMES_ATTR_KEY) { - return SpecialAttributesFunctions.SetDimNamesAttributeNode.create(); + return SetDimNamesAttributeNode.create(); } else if (name == RRuntime.ROWNAMES_ATTR_KEY) { - return SpecialAttributesFunctions.SetRowNamesAttributeNode.create(); + return SetRowNamesAttributeNode.create(); } else if (name == RRuntime.CLASS_ATTR_KEY) { throw RInternalError.unimplemented("The \"class\" attribute should be set using a separate method"); } else { @@ -766,6 +766,8 @@ public final class SpecialAttributesFunctions { public abstract static class SetRowNamesAttributeNode extends SetSpecialAttributeNode { + private final ConditionProfile nullRowNamesProfile = ConditionProfile.createBinaryProfile(); + protected SetRowNamesAttributeNode() { super(RRuntime.ROWNAMES_ATTR_KEY); } @@ -774,6 +776,20 @@ public final class SpecialAttributesFunctions { return SpecialAttributesFunctionsFactory.SetRowNamesAttributeNodeGen.create(); } + public void setRowNames(RAbstractContainer x, RAbstractVector rowNames) { + if (nullRowNamesProfile.profile(rowNames == null)) { + execute(x, RNull.instance); + } else { + execute(x, rowNames); + } + } + + @Specialization(insertBefore = "setAttrInAttributable") + protected void resetRowNames(RAbstractContainer x, @SuppressWarnings("unused") RNull rnull, + @Cached("create()") RemoveRowNamesAttributeNode removeRowNamesAttrNode) { + removeRowNamesAttrNode.execute(x); + } + @Specialization(insertBefore = "setAttrInAttributable") protected void setRowNamesInContainer(RAbstractContainer x, RAbstractVector rowNames, @Cached("createClassProfile()") ValueProfile contClassProfile) { RAbstractContainer xProfiled = contClassProfile.profile(x); @@ -809,12 +825,28 @@ public final class SpecialAttributesFunctions { return SpecialAttributesFunctionsFactory.GetRowNamesAttributeNodeGen.create(); } + public Object getRowNames(RAbstractContainer x) { + return execute(x); + } + + @Specialization(insertBefore = "getAttrFromAttributable") + protected Object getScalarVectorRowNames(@SuppressWarnings("unused") RScalarVector x) { + return RNull.instance; + } + + @Specialization(insertBefore = "getAttrFromAttributable") + protected Object getScalarVectorRowNames(@SuppressWarnings("unused") RSequence x) { + return RNull.class; + } + @Specialization(insertBefore = "getAttrFromAttributable") - protected Object getVectorRowNames(RVector<?> x, + protected Object getVectorRowNames(RAbstractVector x, @Cached("create()") BranchProfile attrNullProfile, @Cached("createBinaryProfile()") ConditionProfile attrStorageProfile, - @Cached("createClassProfile()") ValueProfile xTypeProfile) { - return super.getAttrFromAttributable(x, attrNullProfile, attrStorageProfile, xTypeProfile); + @Cached("createClassProfile()") ValueProfile xTypeProfile, + @Cached("createBinaryProfile()") ConditionProfile nullRowNamesProfile) { + Object res = super.getAttrFromAttributable(x, attrNullProfile, attrStorageProfile, xTypeProfile); + return nullRowNamesProfile.profile(res == null) ? RNull.instance : res; } @Specialization(insertBefore = "getAttrFromAttributable")