Skip to content
Snippets Groups Projects
Commit 299d05c4 authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

get/setRowNames occurences replaced by node invocations

parent 02b6232e
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.object.DynamicObject;
......@@ -38,6 +39,7 @@ import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.GetAttributeNode;
import com.oracle.truffle.r.nodes.attributes.IterableAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.UpdateSharedAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
......@@ -61,7 +63,6 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
public abstract class Attr extends RBuiltinNode {
private final ConditionProfile searchPartialProfile = ConditionProfile.createBinaryProfile();
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
private final BranchProfile errorProfile = BranchProfile.create();
@CompilationFinal private String cachedName = "";
......@@ -145,13 +146,14 @@ public abstract class Attr extends RBuiltinNode {
}
@Specialization(guards = "isRowNamesAttr(name)")
protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact) {
protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact,
@Cached("create()") GetRowNamesAttributeNode getRowNamesNode) {
// TODO: if exact == false, check for partial match (there is an ignored tests for it)
DynamicObject attributes = container.getAttributes();
if (attributes == null) {
return RNull.instance;
} else {
return getFullRowNames(container.getRowNames(attrProfiles));
return getFullRowNames(getRowNamesNode.getRowNames(container));
}
}
......
......@@ -28,11 +28,13 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.IntValueProfile;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError;
......@@ -48,7 +50,6 @@ import com.oracle.truffle.r.runtime.env.REnvironment;
public abstract class ShortRowNames extends RBuiltinNode {
private final BranchProfile naValueMet = BranchProfile.create();
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
private final BranchProfile errorProfile = BranchProfile.create();
private final ValueProfile operandTypeProfile = ValueProfile.createClassProfile();
......@@ -62,11 +63,12 @@ public abstract class ShortRowNames extends RBuiltinNode {
private final IntValueProfile typeProfile = IntValueProfile.createIdentityProfile();
@Specialization
protected Object getNames(Object originalOperand, int originalType) {
protected Object getNames(Object originalOperand, int originalType,
@Cached("create()") GetRowNamesAttributeNode getRowNamesNode) {
Object operand = operandTypeProfile.profile(originalOperand);
Object rowNames;
if (operand instanceof RAbstractContainer) {
rowNames = ((RAbstractContainer) operand).getRowNames(attrProfiles);
rowNames = getRowNamesNode.getRowNames((RAbstractContainer) operand);
} else if (operand instanceof REnvironment) {
if (getRowNamesAttrNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
......
......@@ -163,7 +163,11 @@ public abstract class UpdateAttr extends RBuiltinNode {
setClassAttrNode.reset(result);
return result;
} else if (internedName == RRuntime.ROWNAMES_ATTR_KEY) {
result.setRowNames(null);
if (setRowNamesAttrNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
setRowNamesAttrNode = insert(SetRowNamesAttributeNode.create());
}
setRowNamesAttrNode.setRowNames(result, null);
} else if (result.getAttributes() != null) {
result.removeAttr(attrProfiles, internedName);
}
......@@ -213,7 +217,7 @@ public abstract class UpdateAttr extends RBuiltinNode {
CompilerDirectives.transferToInterpreterAndInvalidate();
setRowNamesAttrNode = insert(SetRowNamesAttributeNode.create());
}
setRowNamesAttrNode.execute(result, castVector(value));
setRowNamesAttrNode.setRowNames(result, castVector(value));
} else {
// generic attribute
if (setGenAttrNode == null) {
......
......@@ -34,6 +34,7 @@ import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.SetAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetRowNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.unary.CastIntegerNode;
......@@ -66,6 +67,7 @@ public abstract class UpdateAttributes extends RBuiltinNode {
@Child private CastToVectorNode castVector;
@Child private SetAttributeNode setAttrNode;
@Child private SetDimAttributeNode setDimNode;
@Child private SetRowNamesAttributeNode setRowNamesNode;
@Override
protected void createCasts(CastBuilder casts) {
......@@ -208,7 +210,11 @@ public abstract class UpdateAttributes extends RBuiltinNode {
}
res = result;
} else if (attrName.equals(RRuntime.ROWNAMES_ATTR_KEY)) {
res.setRowNames(castVector(value));
if (setRowNamesNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
setRowNamesNode = insert(SetRowNamesAttributeNode.create());
}
setRowNamesNode.setRowNames(res, castVector(value));
} else {
if (value == RNull.instance) {
res.removeAttr(attrProfiles, attrName);
......
......@@ -184,9 +184,9 @@ public final class SpecialAttributesFunctions {
} else if (name == RRuntime.DIM_ATTR_KEY) {
return SetDimAttributeNode.create();
} else if (name == RRuntime.DIMNAMES_ATTR_KEY) {
return SpecialAttributesFunctions.SetDimNamesAttributeNode.create();
return SetDimNamesAttributeNode.create();
} else if (name == RRuntime.ROWNAMES_ATTR_KEY) {
return SpecialAttributesFunctions.SetRowNamesAttributeNode.create();
return SetRowNamesAttributeNode.create();
} else if (name == RRuntime.CLASS_ATTR_KEY) {
throw RInternalError.unimplemented("The \"class\" attribute should be set using a separate method");
} else {
......@@ -766,6 +766,8 @@ public final class SpecialAttributesFunctions {
public abstract static class SetRowNamesAttributeNode extends SetSpecialAttributeNode {
private final ConditionProfile nullRowNamesProfile = ConditionProfile.createBinaryProfile();
protected SetRowNamesAttributeNode() {
super(RRuntime.ROWNAMES_ATTR_KEY);
}
......@@ -774,6 +776,20 @@ public final class SpecialAttributesFunctions {
return SpecialAttributesFunctionsFactory.SetRowNamesAttributeNodeGen.create();
}
public void setRowNames(RAbstractContainer x, RAbstractVector rowNames) {
if (nullRowNamesProfile.profile(rowNames == null)) {
execute(x, RNull.instance);
} else {
execute(x, rowNames);
}
}
@Specialization(insertBefore = "setAttrInAttributable")
protected void resetRowNames(RAbstractContainer x, @SuppressWarnings("unused") RNull rnull,
@Cached("create()") RemoveRowNamesAttributeNode removeRowNamesAttrNode) {
removeRowNamesAttrNode.execute(x);
}
@Specialization(insertBefore = "setAttrInAttributable")
protected void setRowNamesInContainer(RAbstractContainer x, RAbstractVector rowNames, @Cached("createClassProfile()") ValueProfile contClassProfile) {
RAbstractContainer xProfiled = contClassProfile.profile(x);
......@@ -809,12 +825,28 @@ public final class SpecialAttributesFunctions {
return SpecialAttributesFunctionsFactory.GetRowNamesAttributeNodeGen.create();
}
public Object getRowNames(RAbstractContainer x) {
return execute(x);
}
@Specialization(insertBefore = "getAttrFromAttributable")
protected Object getScalarVectorRowNames(@SuppressWarnings("unused") RScalarVector x) {
return RNull.instance;
}
@Specialization(insertBefore = "getAttrFromAttributable")
protected Object getScalarVectorRowNames(@SuppressWarnings("unused") RSequence x) {
return RNull.class;
}
@Specialization(insertBefore = "getAttrFromAttributable")
protected Object getVectorRowNames(RVector<?> x,
protected Object getVectorRowNames(RAbstractVector x,
@Cached("create()") BranchProfile attrNullProfile,
@Cached("createBinaryProfile()") ConditionProfile attrStorageProfile,
@Cached("createClassProfile()") ValueProfile xTypeProfile) {
return super.getAttrFromAttributable(x, attrNullProfile, attrStorageProfile, xTypeProfile);
@Cached("createClassProfile()") ValueProfile xTypeProfile,
@Cached("createBinaryProfile()") ConditionProfile nullRowNamesProfile) {
Object res = super.getAttrFromAttributable(x, attrNullProfile, attrStorageProfile, xTypeProfile);
return nullRowNamesProfile.profile(res == null) ? RNull.instance : res;
}
@Specialization(insertBefore = "getAttrFromAttributable")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment