From a0955c2bf0bac1f96a01de56866c47ab27c7df8b Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Tue, 2 Jan 2018 18:52:47 +0100 Subject: [PATCH] Fix Rf_getAttrib works with string vectors GetAttributeNode updated to handle special cases of names and dimnames and reused in both attr builtin and Rf_getAttrib RFF function. --- .../ffi/impl/common/JavaUpCallsRFFIImpl.java | 15 +------- .../ffi/impl/nodes/AttributesAccessNodes.java | 35 ++++++++++++++++++ .../truffle/r/ffi/impl/nodes/CoerceNodes.java | 1 - .../r/ffi/impl/upcalls/StdUpCallsRFFI.java | 2 + .../truffle/r/nodes/builtin/base/Attr.java | 18 +-------- .../r/nodes/attributes/GetAttributeNode.java | 37 ++++++++++++++++++- .../packages/testrffi/testrffi/R/testrffi.R | 4 ++ .../packages/testrffi/testrffi/src/testrffi.c | 4 ++ .../packages/testrffi/testrffi/src/testrffi.h | 2 + .../testrffi/testrffi/tests/simpleTests.R | 2 + 10 files changed, 86 insertions(+), 34 deletions(-) diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java index c2b3d6e951..eeb5dd0ceb 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java @@ -287,21 +287,8 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI { } @Override - @TruffleBoundary public Object Rf_getAttrib(Object obj, Object name) { - Object result = RNull.instance; - if (obj instanceof RAttributable) { - RAttributable attrObj = (RAttributable) obj; - DynamicObject attrs = attrObj.getAttributes(); - if (attrs != null) { - String nameAsString = Utils.intern(((RSymbol) name).getName()); - Object attr = attrs.get(nameAsString); - if (attr != null) { - result = attr; - } - } - } - return result; + throw implementedAsNode(); } @Override diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java index 7b6ef2821f..7f3ba16ac6 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java @@ -22,17 +22,24 @@ */ package com.oracle.truffle.r.ffi.impl.nodes; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.singleElement; +import static com.oracle.truffle.r.nodes.builtin.casts.fluent.CastNodeBuilder.newCastBuilder; + import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.ATTRIBNodeGen; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.CopyMostAttribNodeGen; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.TAGNodeGen; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode; +import com.oracle.truffle.r.nodes.attributes.GetAttributeNode; import com.oracle.truffle.r.nodes.attributes.GetAttributesNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; +import com.oracle.truffle.r.nodes.function.opt.UpdateShareableChildValueNode; +import com.oracle.truffle.r.nodes.unary.CastNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; @@ -46,9 +53,37 @@ import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RPairList; import com.oracle.truffle.r.runtime.data.RStringVector; +import com.oracle.truffle.r.runtime.data.RSymbol; public final class AttributesAccessNodes { + public static final class GetAttrib extends FFIUpCallNode.Arg2 { + @Child private GetAttributeNode getAttributeNode = GetAttributeNode.create(); + @Child private UpdateShareableChildValueNode sharedAttrUpdate = UpdateShareableChildValueNode.create(); + @Child private CastNode castStringNode; + private final ConditionProfile nameIsSymbolProfile = ConditionProfile.createBinaryProfile(); + + @Override + public Object executeObject(Object source, Object nameObj) { + String name; + if (nameIsSymbolProfile.profile(nameObj instanceof RSymbol)) { + name = ((RSymbol) nameObj).getName(); + } else { + name = castToString(nameObj); + } + Object result = getAttributeNode.execute(source, name); + return result == null ? RNull.instance : sharedAttrUpdate.updateState(source, result); + } + + private String castToString(Object name) { + if (castStringNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + castStringNode = insert(newCastBuilder().asStringVector().mustBe(singleElement()).findFirst().buildCastNode()); + } + return (String) castStringNode.doCast(name); + } + } + public abstract static class ATTRIB extends FFIUpCallNode.Arg1 { @Child private GetAttributesNode getAttributesNode; diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java index 63bbc904a2..b44ce88447 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java @@ -123,7 +123,6 @@ public final class CoerceNodes { public abstract static class AsCharacterFactor extends FFIUpCallNode.Arg1 { @Child private InheritsCheckNode inheritsFactorNode = InheritsCheckNode.createFactor(); - @Child private GetAttributeNode getAttributeNode = GetAttributeNode.create(); @Child private RFactorNodes.GetLevels getLevels = RFactorNodes.GetLevels.create(); @Specialization diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java index cfc4c80acb..90dba617d4 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java @@ -28,6 +28,7 @@ import com.oracle.truffle.r.ffi.impl.nodes.AsLogicalNode; import com.oracle.truffle.r.ffi.impl.nodes.AsRealNode; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.ATTRIB; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.CopyMostAttrib; +import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.GetAttrib; import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.TAG; import com.oracle.truffle.r.ffi.impl.nodes.CoerceNodes.CoerceVectorNode; import com.oracle.truffle.r.ffi.impl.nodes.CoerceNodes.VectorToPairListNode; @@ -122,6 +123,7 @@ public interface StdUpCallsRFFI { @RFFIUpCallNode(ATTRIB.class) Object ATTRIB(Object obj); + @RFFIUpCallNode(GetAttrib.class) Object Rf_getAttrib(Object obj, Object name); void Rf_setAttrib(Object obj, Object name, Object val); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java index 3e2489c23c..7d24d0a540 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java @@ -110,23 +110,11 @@ public abstract class Attr extends RBuiltinNode.Arg3 { return container; } - @Specialization(guards = "!isRowNamesAttr(name)") + @Specialization protected Object attr(RAbstractContainer container, String name, boolean exact) { return attrRA(container, intern.execute(name), exact); } - @Specialization(guards = "isRowNamesAttr(name)") - protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact, - @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) { - // TODO: if exact == false, check for partial match (there is an ignored tests for it) - DynamicObject attributes = container.getAttributes(); - if (attributes == null) { - return RNull.instance; - } else { - return GetAttributesNode.getFullRowNames(getRowNamesNode.getRowNames(container)); - } - } - /** * All other, non-performance centric, {@link RAttributable} types. */ @@ -141,8 +129,4 @@ public abstract class Attr extends RBuiltinNode.Arg3 { throw RError.nyi(this, "object cannot be attributed"); } } - - protected static boolean isRowNamesAttr(String name) { - return name.equals(RRuntime.ROWNAMES_ATTR_KEY); - } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java index dd749a55ff..ec04358a1e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java @@ -32,8 +32,12 @@ import com.oracle.truffle.api.object.Shape; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode; +import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RAttributable; import com.oracle.truffle.r.runtime.data.RAttributeStorage; +import com.oracle.truffle.r.runtime.data.RNull; /** * This node is responsible for retrieving a value from an arbitrary attribute. It accepts both @@ -54,8 +58,25 @@ public abstract class GetAttributeNode extends AttributeAccessNode { public abstract Object execute(Object attrs, String name); + @Specialization + protected RNull attr(RNull container, @SuppressWarnings("unused") String name) { + return container; + } + + @Specialization(guards = "isRowNamesAttr(name)") + protected Object getRowNames(DynamicObject attrs, @SuppressWarnings("unused") String name, + @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) { + return GetAttributesNode.getFullRowNames(getRowNamesNode.execute(attrs)); + } + + @Specialization(guards = "isNamesAttr(name)") + protected Object getNames(DynamicObject attrs, @SuppressWarnings("unused") String name, + @Cached("create()") GetNamesAttributeNode getNamesAttributeNode) { + return getNamesAttributeNode.execute(attrs); + } + @Specialization(limit = "3", // - guards = {"cachedName.equals(name)", "shapeCheck(shape, attrs)"}, // + guards = {"!isSpecialAttribute(name)", "cachedName.equals(name)", "shapeCheck(shape, attrs)"}, // assumptions = {"shape.getValidAssumption()"}) protected Object getAttrCached(DynamicObject attrs, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") @Cached("name") String cachedName, @@ -65,7 +86,7 @@ public abstract class GetAttributeNode extends AttributeAccessNode { } @TruffleBoundary - @Specialization(replaces = {"getAttrCached"}) + @Specialization(replaces = "getAttrCached", guards = "!isSpecialAttribute(name)") protected Object getAttrFallback(DynamicObject attrs, String name) { return attrs.get(name); } @@ -95,4 +116,16 @@ public abstract class GetAttributeNode extends AttributeAccessNode { return recursive.execute(attributes, name); } + + protected static boolean isRowNamesAttr(String name) { + return name.equals(RRuntime.ROWNAMES_ATTR_KEY); + } + + protected static boolean isNamesAttr(String name) { + return name.equals(RRuntime.NAMES_ATTR_KEY); + } + + protected static boolean isSpecialAttribute(String name) { + return isRowNamesAttr(name) || isNamesAttr(name); + } } diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R index 214488fecf..00e256bf93 100644 --- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R +++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R @@ -165,6 +165,10 @@ rffi.ATTRIB <- function(x) { .Call('test_ATTRIB', x); } +rffi.getAttrib <- function(source, name) { + .Call('test_getAttrib', source, name); +} + rffi.getStringNA <- function() { .Call("test_stringNA") } diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c index 3e6edb7ee6..7ab9d8c4be 100644 --- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c +++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c @@ -347,6 +347,10 @@ SEXP test_ATTRIB(SEXP x) { return ATTRIB(x); } +SEXP test_getAttrib(SEXP source, SEXP name) { + return Rf_getAttrib(source, name); +} + SEXP test_stringNA(void) { SEXP x = allocVector(STRSXP, 1); SET_STRING_ELT(x, 0, NA_STRING); diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h index 2fbab660b9..a943132199 100644 --- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h +++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h @@ -92,6 +92,8 @@ extern SEXP test_coerceVector(SEXP x, SEXP mode); extern SEXP test_ATTRIB(SEXP); +extern SEXP test_getAttrib(SEXP,SEXP); + extern SEXP test_stringNA(void); extern SEXP test_captureDotsWithSingleElement(SEXP env); diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R index e1fea9c3cc..f74c7dce9e 100644 --- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R +++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R @@ -30,6 +30,8 @@ x <- list(1) attr(x, 'myattr') <- 'hello'; attrs <- rffi.ATTRIB(x) stopifnot(attrs[[1]] == 'hello') +attr <- rffi.getAttrib(x, 'myattr') +stopifnot(attr == 'hello') # loess invokes loess_raw native function passing in string value as argument and that is what we test here. loess(dist ~ speed, cars); -- GitLab