From a0955c2bf0bac1f96a01de56866c47ab27c7df8b Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Tue, 2 Jan 2018 18:52:47 +0100
Subject: [PATCH] Fix Rf_getAttrib works with string vectors

GetAttributeNode updated to handle special cases of names and dimnames and
reused in both attr builtin and Rf_getAttrib RFF function.
---
 .../ffi/impl/common/JavaUpCallsRFFIImpl.java  | 15 +-------
 .../ffi/impl/nodes/AttributesAccessNodes.java | 35 ++++++++++++++++++
 .../truffle/r/ffi/impl/nodes/CoerceNodes.java |  1 -
 .../r/ffi/impl/upcalls/StdUpCallsRFFI.java    |  2 +
 .../truffle/r/nodes/builtin/base/Attr.java    | 18 +--------
 .../r/nodes/attributes/GetAttributeNode.java  | 37 ++++++++++++++++++-
 .../packages/testrffi/testrffi/R/testrffi.R   |  4 ++
 .../packages/testrffi/testrffi/src/testrffi.c |  4 ++
 .../packages/testrffi/testrffi/src/testrffi.h |  2 +
 .../testrffi/testrffi/tests/simpleTests.R     |  2 +
 10 files changed, 86 insertions(+), 34 deletions(-)

diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java
index c2b3d6e951..eeb5dd0ceb 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/common/JavaUpCallsRFFIImpl.java
@@ -287,21 +287,8 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI {
     }
 
     @Override
-    @TruffleBoundary
     public Object Rf_getAttrib(Object obj, Object name) {
-        Object result = RNull.instance;
-        if (obj instanceof RAttributable) {
-            RAttributable attrObj = (RAttributable) obj;
-            DynamicObject attrs = attrObj.getAttributes();
-            if (attrs != null) {
-                String nameAsString = Utils.intern(((RSymbol) name).getName());
-                Object attr = attrs.get(nameAsString);
-                if (attr != null) {
-                    result = attr;
-                }
-            }
-        }
-        return result;
+        throw implementedAsNode();
     }
 
     @Override
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java
index 7b6ef2821f..7f3ba16ac6 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AttributesAccessNodes.java
@@ -22,17 +22,24 @@
  */
 package com.oracle.truffle.r.ffi.impl.nodes;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.singleElement;
+import static com.oracle.truffle.r.nodes.builtin.casts.fluent.CastNodeBuilder.newCastBuilder;
+
 import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Fallback;
 import com.oracle.truffle.api.dsl.Specialization;
+import com.oracle.truffle.api.profiles.ConditionProfile;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.ATTRIBNodeGen;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.CopyMostAttribNodeGen;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodesFactory.TAGNodeGen;
 import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode;
+import com.oracle.truffle.r.nodes.attributes.GetAttributeNode;
 import com.oracle.truffle.r.nodes.attributes.GetAttributesNode;
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
+import com.oracle.truffle.r.nodes.function.opt.UpdateShareableChildValueNode;
+import com.oracle.truffle.r.nodes.unary.CastNode;
 import com.oracle.truffle.r.runtime.ArgumentsSignature;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RError.Message;
@@ -46,9 +53,37 @@ import com.oracle.truffle.r.runtime.data.RList;
 import com.oracle.truffle.r.runtime.data.RNull;
 import com.oracle.truffle.r.runtime.data.RPairList;
 import com.oracle.truffle.r.runtime.data.RStringVector;
+import com.oracle.truffle.r.runtime.data.RSymbol;
 
 public final class AttributesAccessNodes {
 
+    public static final class GetAttrib extends FFIUpCallNode.Arg2 {
+        @Child private GetAttributeNode getAttributeNode = GetAttributeNode.create();
+        @Child private UpdateShareableChildValueNode sharedAttrUpdate = UpdateShareableChildValueNode.create();
+        @Child private CastNode castStringNode;
+        private final ConditionProfile nameIsSymbolProfile = ConditionProfile.createBinaryProfile();
+
+        @Override
+        public Object executeObject(Object source, Object nameObj) {
+            String name;
+            if (nameIsSymbolProfile.profile(nameObj instanceof RSymbol)) {
+                name = ((RSymbol) nameObj).getName();
+            } else {
+                name = castToString(nameObj);
+            }
+            Object result = getAttributeNode.execute(source, name);
+            return result == null ? RNull.instance : sharedAttrUpdate.updateState(source, result);
+        }
+
+        private String castToString(Object name) {
+            if (castStringNode == null) {
+                CompilerDirectives.transferToInterpreterAndInvalidate();
+                castStringNode = insert(newCastBuilder().asStringVector().mustBe(singleElement()).findFirst().buildCastNode());
+            }
+            return (String) castStringNode.doCast(name);
+        }
+    }
+
     public abstract static class ATTRIB extends FFIUpCallNode.Arg1 {
         @Child private GetAttributesNode getAttributesNode;
 
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java
index 63bbc904a2..b44ce88447 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/CoerceNodes.java
@@ -123,7 +123,6 @@ public final class CoerceNodes {
     public abstract static class AsCharacterFactor extends FFIUpCallNode.Arg1 {
 
         @Child private InheritsCheckNode inheritsFactorNode = InheritsCheckNode.createFactor();
-        @Child private GetAttributeNode getAttributeNode = GetAttributeNode.create();
         @Child private RFactorNodes.GetLevels getLevels = RFactorNodes.GetLevels.create();
 
         @Specialization
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java
index cfc4c80acb..90dba617d4 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/upcalls/StdUpCallsRFFI.java
@@ -28,6 +28,7 @@ import com.oracle.truffle.r.ffi.impl.nodes.AsLogicalNode;
 import com.oracle.truffle.r.ffi.impl.nodes.AsRealNode;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.ATTRIB;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.CopyMostAttrib;
+import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.GetAttrib;
 import com.oracle.truffle.r.ffi.impl.nodes.AttributesAccessNodes.TAG;
 import com.oracle.truffle.r.ffi.impl.nodes.CoerceNodes.CoerceVectorNode;
 import com.oracle.truffle.r.ffi.impl.nodes.CoerceNodes.VectorToPairListNode;
@@ -122,6 +123,7 @@ public interface StdUpCallsRFFI {
     @RFFIUpCallNode(ATTRIB.class)
     Object ATTRIB(Object obj);
 
+    @RFFIUpCallNode(GetAttrib.class)
     Object Rf_getAttrib(Object obj, Object name);
 
     void Rf_setAttrib(Object obj, Object name, Object val);
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
index 3e2489c23c..7d24d0a540 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
@@ -110,23 +110,11 @@ public abstract class Attr extends RBuiltinNode.Arg3 {
         return container;
     }
 
-    @Specialization(guards = "!isRowNamesAttr(name)")
+    @Specialization
     protected Object attr(RAbstractContainer container, String name, boolean exact) {
         return attrRA(container, intern.execute(name), exact);
     }
 
-    @Specialization(guards = "isRowNamesAttr(name)")
-    protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact,
-                    @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) {
-        // TODO: if exact == false, check for partial match (there is an ignored tests for it)
-        DynamicObject attributes = container.getAttributes();
-        if (attributes == null) {
-            return RNull.instance;
-        } else {
-            return GetAttributesNode.getFullRowNames(getRowNamesNode.getRowNames(container));
-        }
-    }
-
     /**
      * All other, non-performance centric, {@link RAttributable} types.
      */
@@ -141,8 +129,4 @@ public abstract class Attr extends RBuiltinNode.Arg3 {
             throw RError.nyi(this, "object cannot be attributed");
         }
     }
-
-    protected static boolean isRowNamesAttr(String name) {
-        return name.equals(RRuntime.ROWNAMES_ATTR_KEY);
-    }
 }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java
index dd749a55ff..ec04358a1e 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/GetAttributeNode.java
@@ -32,8 +32,12 @@ import com.oracle.truffle.api.object.Shape;
 import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.api.profiles.ConditionProfile;
 import com.oracle.truffle.api.profiles.ValueProfile;
+import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
+import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetRowNamesAttributeNode;
+import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.data.RAttributable;
 import com.oracle.truffle.r.runtime.data.RAttributeStorage;
+import com.oracle.truffle.r.runtime.data.RNull;
 
 /**
  * This node is responsible for retrieving a value from an arbitrary attribute. It accepts both
@@ -54,8 +58,25 @@ public abstract class GetAttributeNode extends AttributeAccessNode {
 
     public abstract Object execute(Object attrs, String name);
 
+    @Specialization
+    protected RNull attr(RNull container, @SuppressWarnings("unused") String name) {
+        return container;
+    }
+
+    @Specialization(guards = "isRowNamesAttr(name)")
+    protected Object getRowNames(DynamicObject attrs, @SuppressWarnings("unused") String name,
+                    @Cached("create()") GetRowNamesAttributeNode getRowNamesNode) {
+        return GetAttributesNode.getFullRowNames(getRowNamesNode.execute(attrs));
+    }
+
+    @Specialization(guards = "isNamesAttr(name)")
+    protected Object getNames(DynamicObject attrs, @SuppressWarnings("unused") String name,
+                    @Cached("create()") GetNamesAttributeNode getNamesAttributeNode) {
+        return getNamesAttributeNode.execute(attrs);
+    }
+
     @Specialization(limit = "3", //
-                    guards = {"cachedName.equals(name)", "shapeCheck(shape, attrs)"}, //
+                    guards = {"!isSpecialAttribute(name)", "cachedName.equals(name)", "shapeCheck(shape, attrs)"}, //
                     assumptions = {"shape.getValidAssumption()"})
     protected Object getAttrCached(DynamicObject attrs, @SuppressWarnings("unused") String name,
                     @SuppressWarnings("unused") @Cached("name") String cachedName,
@@ -65,7 +86,7 @@ public abstract class GetAttributeNode extends AttributeAccessNode {
     }
 
     @TruffleBoundary
-    @Specialization(replaces = {"getAttrCached"})
+    @Specialization(replaces = "getAttrCached", guards = "!isSpecialAttribute(name)")
     protected Object getAttrFallback(DynamicObject attrs, String name) {
         return attrs.get(name);
     }
@@ -95,4 +116,16 @@ public abstract class GetAttributeNode extends AttributeAccessNode {
 
         return recursive.execute(attributes, name);
     }
+
+    protected static boolean isRowNamesAttr(String name) {
+        return name.equals(RRuntime.ROWNAMES_ATTR_KEY);
+    }
+
+    protected static boolean isNamesAttr(String name) {
+        return name.equals(RRuntime.NAMES_ATTR_KEY);
+    }
+
+    protected static boolean isSpecialAttribute(String name) {
+        return isRowNamesAttr(name) || isNamesAttr(name);
+    }
 }
diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R
index 214488fecf..00e256bf93 100644
--- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R
+++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/R/testrffi.R
@@ -165,6 +165,10 @@ rffi.ATTRIB <- function(x) {
     .Call('test_ATTRIB', x);
 }
 
+rffi.getAttrib <- function(source, name) {
+    .Call('test_getAttrib', source, name);
+}
+
 rffi.getStringNA <- function() {
     .Call("test_stringNA")
 }
diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c
index 3e6edb7ee6..7ab9d8c4be 100644
--- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c
+++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.c
@@ -347,6 +347,10 @@ SEXP test_ATTRIB(SEXP x) {
     return ATTRIB(x);
 }
 
+SEXP test_getAttrib(SEXP source, SEXP name) {
+    return Rf_getAttrib(source, name);
+}
+
 SEXP test_stringNA(void) {
     SEXP x = allocVector(STRSXP, 1);
     SET_STRING_ELT(x, 0, NA_STRING);
diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h
index 2fbab660b9..a943132199 100644
--- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h
+++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/src/testrffi.h
@@ -92,6 +92,8 @@ extern SEXP test_coerceVector(SEXP x, SEXP mode);
 
 extern SEXP test_ATTRIB(SEXP);
 
+extern SEXP test_getAttrib(SEXP,SEXP);
+
 extern SEXP test_stringNA(void);
 
 extern SEXP test_captureDotsWithSingleElement(SEXP env);
diff --git a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R
index e1fea9c3cc..f74c7dce9e 100644
--- a/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R
+++ b/com.oracle.truffle.r.test.native/packages/testrffi/testrffi/tests/simpleTests.R
@@ -30,6 +30,8 @@ x <- list(1)
 attr(x, 'myattr') <- 'hello';
 attrs <- rffi.ATTRIB(x)
 stopifnot(attrs[[1]] == 'hello')
+attr <- rffi.getAttrib(x, 'myattr')
+stopifnot(attr == 'hello')
 
 # loess invokes loess_raw native function passing in string value as argument and that is what we test here.
 loess(dist ~ speed, cars);
-- 
GitLab