From f0aca5a5fd12da364d8ceea0a70bcd068b99f670 Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Thu, 4 Aug 2016 11:03:03 +0200
Subject: [PATCH] Attr: cast pipelines + exact argument

---
 .../truffle/r/nodes/builtin/base/Attr.java    | 89 ++++++++-----------
 .../truffle/r/test/ExpectedTestOutput.test    |  8 ++
 .../r/test/builtins/TestBuiltin_attr.java     |  6 ++
 3 files changed, 50 insertions(+), 53 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
index ff4f6c27ec..87826e3771 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Attr.java
@@ -22,6 +22,9 @@
  */
 package com.oracle.truffle.r.nodes.builtin.base;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.singleElement;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
 
@@ -30,10 +33,11 @@ import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Fallback;
 import com.oracle.truffle.api.dsl.Specialization;
-import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
 import com.oracle.truffle.r.runtime.RError;
+import com.oracle.truffle.r.runtime.RError.Message;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
 import com.oracle.truffle.r.runtime.data.RAttributable;
@@ -42,8 +46,8 @@ import com.oracle.truffle.r.runtime.data.RAttributes;
 import com.oracle.truffle.r.runtime.data.RAttributes.RAttribute;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RInteger;
+import com.oracle.truffle.r.runtime.data.RMissing;
 import com.oracle.truffle.r.runtime.data.RNull;
-import com.oracle.truffle.r.runtime.data.RStringVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractContainer;
 import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
@@ -52,12 +56,23 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 public abstract class Attr extends RBuiltinNode {
 
     private final ConditionProfile searchPartialProfile = ConditionProfile.createBinaryProfile();
-    private final BranchProfile errorProfile = BranchProfile.create();
     private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
 
     @CompilationFinal private String cachedName = "";
     @CompilationFinal private String cachedInternedName = "";
 
+    @Override
+    public Object[] getDefaultParameterValues() {
+        return new Object[]{RMissing.instance, RMissing.instance, RRuntime.asLogical(false)};
+    }
+
+    @Override
+    protected void createCasts(CastBuilder casts) {
+        casts.arg("x").mustBe(RAttributable.class, Message.UNIMPLEMENTED_ARGUMENT_TYPE);
+        casts.arg("which").mustBe(stringValue(), Message.MUST_BE_CHARACTER, "which").asStringVector().mustBe(singleElement(), RError.Message.EXACTLY_ONE_WHICH).findFirst();
+        casts.arg("exact").asLogicalVector().findFirst().map(toBoolean());
+    }
+
     private String intern(String name) {
         if (cachedName == null) {
             // unoptimized case
@@ -95,41 +110,32 @@ public abstract class Attr extends RBuiltinNode {
         return val;
     }
 
-    private Object attrRA(RAttributable attributable, String name) {
+    private Object attrRA(RAttributable attributable, String name, boolean exact) {
         RAttributes attributes = attributable.getAttributes();
         if (attributes == null) {
             return RNull.instance;
         } else {
             Object result = attributes.get(name);
-            if (searchPartialProfile.profile(result == null)) {
+            if (searchPartialProfile.profile(!exact && result == null)) {
                 return searchKeyPartial(attributes, name);
             }
-            return result;
+            return result == null ? RNull.instance : result;
         }
     }
 
     @Specialization
-    protected RNull attr(RNull container, @SuppressWarnings("unused") String name) {
+    protected RNull attr(RNull container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact) {
         return container;
     }
 
     @Specialization(guards = "!isRowNamesAttr(name)")
-    protected Object attr(RAbstractContainer container, String name) {
-        return attrRA(container, intern(name));
-    }
-
-    public static Object getFullRowNames(Object a) {
-        if (a == RNull.instance) {
-            return RNull.instance;
-        } else {
-            RAbstractVector rowNames = (RAbstractVector) a;
-            return rowNames.getElementClass() == RInteger.class && rowNames.getLength() == 2 && RRuntime.isNA(((RAbstractIntVector) rowNames).getDataAt(0)) ? RDataFactory.createIntSequence(1, 1,
-                            Math.abs(((RAbstractIntVector) rowNames).getDataAt(1))) : a;
-        }
+    protected Object attr(RAbstractContainer container, String name, boolean exact) {
+        return attrRA(container, intern(name), exact);
     }
 
     @Specialization(guards = "isRowNamesAttr(name)")
-    protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name) {
+    protected Object attrRowNames(RAbstractContainer container, @SuppressWarnings("unused") String name, @SuppressWarnings("unused") boolean exact) {
+        // TODO: if exact == false, check for partial match (there is an ignored tests for it)
         RAttributes attributes = container.getAttributes();
         if (attributes == null) {
             return RNull.instance;
@@ -138,49 +144,26 @@ public abstract class Attr extends RBuiltinNode {
         }
     }
 
-    @Specialization(guards = {"exactlyOne(name)", "isRowNamesAttr(name)"})
-    protected Object attrRowNames(RAbstractContainer container, RStringVector name) {
-        return attrRowNames(container, name.getDataAt(0));
-    }
-
-    @Specialization(guards = {"exactlyOne(name)", "!isRowNamesAttr(name)"})
-    protected Object attr(RAbstractContainer container, RStringVector name) {
-        return attr(container, name.getDataAt(0));
-    }
-
-    @SuppressWarnings("unused")
-    @Specialization(guards = "!exactlyOne(name)")
-    protected Object attrEmtpyName(RAbstractContainer container, RStringVector name) {
-        throw RError.error(this, RError.Message.EXACTLY_ONE_WHICH);
-    }
-
     /**
      * All other, non-performance centric, {@link RAttributable} types.
      */
     @Fallback
     @TruffleBoundary
-    protected Object attr(Object object, Object name) {
-        String sname = RRuntime.asString(name);
-        if (sname == null) {
-            throw RError.error(this, RError.Message.MUST_BE_CHARACTER, "which");
-        }
-        if (object instanceof RAttributable) {
-            return attrRA((RAttributable) object, intern(sname));
+    protected Object attr(RAttributable object, String name, boolean exact) {
+        return attrRA(object, intern(name), exact);
+    }
+
+    public static Object getFullRowNames(Object a) {
+        if (a == RNull.instance) {
+            return RNull.instance;
         } else {
-            errorProfile.enter();
-            throw RError.nyi(this, "object cannot be attributed");
+            RAbstractVector rowNames = (RAbstractVector) a;
+            return rowNames.getElementClass() == RInteger.class && rowNames.getLength() == 2 && RRuntime.isNA(((RAbstractIntVector) rowNames).getDataAt(0)) ? RDataFactory.createIntSequence(1, 1,
+                            Math.abs(((RAbstractIntVector) rowNames).getDataAt(1))) : a;
         }
     }
 
     protected static boolean isRowNamesAttr(String name) {
         return name.equals(RRuntime.ROWNAMES_ATTR_KEY);
     }
-
-    protected static boolean isRowNamesAttr(RStringVector name) {
-        return isRowNamesAttr(name.getDataAt(0));
-    }
-
-    protected static boolean exactlyOne(RStringVector name) {
-        return name.getLength() == 1;
-    }
 }
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index b196871331..65317939d0 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -7784,6 +7784,14 @@ In attach(list(x = 42), pos = "string") : NAs introduced by coercion
 #detach('string')
 Error in detach("string") : invalid 'name' argument
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_attr.testExactMatch
+#x <- c(1, 3); attr(x, 'abc') <- 42; attr(x, 'ab', exact=TRUE)
+NULL
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_attr.testExactMatch
+#x <- c(1,2); attr(x, 'row.namess') <- 42; attr(x, 'row.names')
+[1] 42
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_attr.testattr1
 #argv <- list(structure(c(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1), .Dim = c(32L, 23L), .Dimnames = list(c('1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32'), c('(Intercept)', 'HairBrown', 'HairRed', 'HairBlond', 'EyeBlue', 'EyeHazel', 'EyeGreen', 'SexFemale', 'HairBrown:EyeBlue', 'HairRed:EyeBlue', 'HairBlond:EyeBlue', 'HairBrown:EyeHazel', 'HairRed:EyeHazel', 'HairBlond:EyeHazel', 'HairBrown:EyeGreen', 'HairRed:EyeGreen', 'HairBlond:EyeGreen', 'HairBrown:SexFemale', 'HairRed:SexFemale', 'HairBlond:SexFemale', 'EyeBlue:SexFemale', 'EyeHazel:SexFemale', 'EyeGreen:SexFemale')), assign = c(0L, 1L, 1L, 1L, 2L, 2L, 2L, 3L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 5L, 5L, 5L, 6L, 6L, 6L), contrasts = structure(list(Hair = 'contr.treatment',     Eye = 'contr.treatment', Sex = 'contr.treatment'), .Names = c('Hair', 'Eye', 'Sex'))), 'assign');attr(argv[[1]],argv[[2]]);
  [1] 0 1 1 1 2 2 2 3 4 4 4 4 4 4 4 4 4 5 5 5 6 6 6
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_attr.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_attr.java
index e5c6ae3818..47cd777263 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_attr.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_attr.java
@@ -237,4 +237,10 @@ public class TestBuiltin_attr extends TestBase {
     public void testattr45() {
         assertEval("argv <- list(c(35.2589338684655, 59.5005803666983, 12.4529321610302, 2.53579570262684, 10.370198579714, 42.0067149618146, 8.14319638132861, 34.0508943233725, 7.78517191057496, 26.9998965458032, 6.70435391953205, 3.62502215105156, 2.59277105754344, 14.4998960151485, 6.70435391953205, 5.8000097831969, 32.741875696675, 59.5015090627504, 13.5512565366133, 4.46460764999704, 9.62989278443572, 42.0073706103832, 8.86141045052292, 59.9511558158597, 7.22940551532861, 27.0003179651772, 7.29566488446303, 6.38233656214029, 2.40767880256155, 14.5001223322046, 7.29566488446303, 10.2116933242272), 'dim');attr(argv[[1]],argv[[2]]);");
     }
+
+    @Test
+    public void testExactMatch() {
+        assertEval("x <- c(1, 3); attr(x, 'abc') <- 42; attr(x, 'ab', exact=TRUE)");
+        assertEval(Ignored.Unimplemented, "x <- c(1,2); attr(x, 'row.namess') <- 42; attr(x, 'row.names')");
+    }
 }
-- 
GitLab