From 59e08eed83cbdcf8030c49ebe98c2cf76000c4f3 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Wed, 28 Jun 2017 14:46:58 +0200
Subject: [PATCH] refactor mapply to proper loops

---
 .../truffle/r/nodes/builtin/base/Mapply.java  | 128 +++++++++++-------
 1 file changed, 78 insertions(+), 50 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java
index e649410a89..7c53ab3e42 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java
@@ -28,10 +28,13 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
 import com.oracle.truffle.api.CompilerAsserts;
+import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.frame.VirtualFrame;
+import com.oracle.truffle.api.nodes.ExplodeLoop;
 import com.oracle.truffle.api.nodes.Node;
+import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.r.nodes.access.WriteVariableNode;
 import com.oracle.truffle.r.nodes.access.WriteVariableNode.Mode;
 import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
@@ -117,7 +120,10 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
             }
         }
 
-        @Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create();
+        @Child private GetNamesAttributeNode getNamesDots = GetNamesAttributeNode.create();
+        @Child private GetNamesAttributeNode getNamesMoreArgs = GetNamesAttributeNode.create();
+
+        private final BranchProfile nonPerfectMatch = BranchProfile.create();
 
         public abstract Object[] execute(VirtualFrame frame, RAbstractListVector dots, RFunction function, RAbstractListVector additionalArguments);
 
@@ -125,18 +131,46 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
             return extractNode.apply(frame, dots.getDataAt(listIndex), new Object[]{i % lengths[listIndex] + 1}, RLogical.TRUE, RLogical.TRUE);
         }
 
-        @SuppressWarnings("unused")
-        @Specialization(limit = "5", guards = {"dots.getLength() == cachedDots.getLength()",
-                        "moreArgs.getLength() == cachedMoreArgs.getLength()", "sameNames(dots, cachedDots)",
-                        "sameNames(moreArgs, cachedMoreArgs)"})
+        @Specialization(limit = "5", guards = {"dots.getLength() == dotsLength", "moreArgs.getLength() == moreArgsLength",
+                        "sameNames(dots, cachedDotsNames)", "sameNames(moreArgs, cachedMoreArgsNames)"})
         protected Object[] cachedMApply(VirtualFrame frame, RAbstractListVector dots, RFunction function, RAbstractListVector moreArgs,
-                        @Cached("dots") RAbstractListVector cachedDots,
-                        @Cached("moreArgs") RAbstractListVector cachedMoreArgs,
-                        @Cached("createElementNodeArray(cachedDots, cachedMoreArgs)") ElementNode[] cachedElementNodeArray,
+                        @Cached("dots.getLength()") int dotsLength,
+                        @Cached("moreArgs.getLength()") int moreArgsLength,
+                        @SuppressWarnings("unused") @Cached("extractNames(dots)") String[] cachedDotsNames,
+                        @SuppressWarnings("unused") @Cached("extractNames(moreArgs)") String[] cachedMoreArgsNames,
+                        @Cached("createElementNodeArray(dotsLength, moreArgsLength, cachedDotsNames, cachedMoreArgsNames)") ElementNode[] cachedElementNodeArray,
                         @Cached("createCallNode(cachedElementNodeArray)") RCallNode callNode) {
-            int dotsLength = dots.getLength();
-            int moreArgsLength = moreArgs.getLength();
             int[] lengths = new int[dotsLength];
+            int maxLength = getDotsLengths(frame, dots, dotsLength, cachedElementNodeArray, lengths);
+            storeAdditionalArguments(frame, moreArgs, dotsLength, moreArgsLength, cachedElementNodeArray);
+            Object[] result = new Object[maxLength];
+            for (int i = 0; i < maxLength; i++) {
+                /* Evaluate and store the arguments */
+                prepareElements(frame, dots, dotsLength, cachedElementNodeArray, lengths, i);
+                /* Now call the function */
+                result[i] = callNode.execute(frame, function);
+            }
+            return result;
+        }
+
+        @ExplodeLoop
+        private static void prepareElements(VirtualFrame frame, RAbstractListVector dots, int dotsLength, ElementNode[] cachedElementNodeArray, int[] lengths, int i) {
+            for (int listIndex = 0; listIndex < dotsLength; listIndex++) {
+                Object vecElement = getVecElement(frame, dots, i, listIndex, lengths, cachedElementNodeArray[listIndex].extractNode);
+                cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, vecElement);
+            }
+        }
+
+        @ExplodeLoop
+        private static void storeAdditionalArguments(VirtualFrame frame, RAbstractListVector moreArgs, int dotsLength, int moreArgsLength, ElementNode[] cachedElementNodeArray) {
+            for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) {
+                // store additional arguments
+                cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, moreArgs.getDataAt(listIndex - dotsLength));
+            }
+        }
+
+        @ExplodeLoop
+        private static int getDotsLengths(VirtualFrame frame, RAbstractListVector dots, int dotsLength, ElementNode[] cachedElementNodeArray, int[] lengths) {
             int maxLength = -1;
             for (int i = 0; i < dotsLength; i++) {
                 int length = cachedElementNodeArray[i].lengthNode.executeInt(frame, dots.getDataAt(i));
@@ -145,22 +179,13 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
                 }
                 lengths[i] = length;
             }
-            for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) {
-                // store additional arguments
-                cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame,
-                                moreArgs.getDataAt(listIndex - dotsLength));
-            }
-            Object[] result = new Object[maxLength];
-            for (int i = 0; i < maxLength; i++) {
-                /* Evaluate and store the arguments */
-                for (int listIndex = 0; listIndex < dotsLength; listIndex++) {
-                    Object vecElement = getVecElement(frame, dots, i, listIndex, lengths, cachedElementNodeArray[listIndex].extractNode);
-                    cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, vecElement);
-                }
-                /* Now call the function */
-                result[i] = callNode.execute(frame, function);
-            }
-            return result;
+            return maxLength;
+        }
+
+        protected static String[] extractNames(RAbstractListVector list) {
+            CompilerAsserts.neverPartOfCompilation();
+            RStringVector names = list.getNames();
+            return names == null ? null : names.getDataCopy();
         }
 
         @Specialization(replaces = "cachedMApply")
@@ -181,13 +206,13 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
             }
             Object[] values = new Object[dotsLength + moreArgsLength];
             String[] names = new String[dotsLength + moreArgsLength];
-            RStringVector dotsNames = getNames.getNames(dots);
+            RStringVector dotsNames = getNamesDots.getNames(dots);
             if (dotsNames != null) {
                 for (int listIndex = 0; listIndex < dotsLength; listIndex++) {
                     names[listIndex] = dotsNames.getDataAt(listIndex).isEmpty() ? null : dotsNames.getDataAt(listIndex);
                 }
             }
-            RStringVector moreArgsNames = getNames.getNames(moreArgs);
+            RStringVector moreArgsNames = getNamesMoreArgs.getNames(moreArgs);
             for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) {
                 values[listIndex] = moreArgs.getDataAt(listIndex - dotsLength);
                 names[listIndex] = moreArgsNames == null ? null : (moreArgsNames.getDataAt(listIndex - dotsLength).isEmpty() ? null : moreArgsNames.getDataAt(listIndex - dotsLength));
@@ -221,17 +246,15 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
             return RCallNode.createCall(Lapply.createCallSourceSection(), null, ArgumentsSignature.get(names), syntaxNodes);
         }
 
-        protected ElementNode[] createElementNodeArray(RAbstractListVector dots, RAbstractListVector moreArgs) {
-            int length = dots.getLength() + moreArgs.getLength();
+        protected ElementNode[] createElementNodeArray(int dotsLength, int moreArgsLength, String[] cachedDotsNames, String[] cachedMoreArgsNames) {
+            int length = dotsLength + moreArgsLength;
             ElementNode[] elementNodes = new ElementNode[length];
-            RStringVector dotsNames = getNames.getNames(dots);
-            for (int i = 0; i < dots.getLength(); i++) {
-                elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1), dotsNames == null ? null : (dotsNames.getDataAt(i).isEmpty() ? null : dotsNames.getDataAt(i))));
+            for (int i = 0; i < dotsLength; i++) {
+                elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1), cachedDotsNames == null ? null : (cachedDotsNames[i].isEmpty() ? null : cachedDotsNames[i])));
             }
-            RStringVector moreArgsNames = getNames.getNames(moreArgs);
-            for (int i = dots.getLength(); i < dots.getLength() + moreArgs.getLength(); i++) {
-                elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1),
-                                moreArgsNames == null ? null : moreArgsNames.getDataAt(i - dots.getLength()).isEmpty() ? null : moreArgsNames.getDataAt(i - dots.getLength())));
+            for (int i = 0; i < moreArgsLength; i++) {
+                elementNodes[i + dotsLength] = insert(
+                                new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1 + dotsLength), cachedMoreArgsNames == null ? null : cachedMoreArgsNames[i].isEmpty() ? null : cachedMoreArgsNames[i]));
             }
             return elementNodes;
         }
@@ -240,29 +263,34 @@ public abstract class Mapply extends RBuiltinNode.Arg3 {
             return ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false);
         }
 
-        protected boolean sameNames(RAbstractListVector list, RAbstractListVector cachedList) {
-            RStringVector listNames = getNames.getNames(list);
-            RStringVector cachedListNames = getNames.getNames(cachedList);
-            if (listNames == null && cachedListNames == null) {
+        protected boolean sameNames(RAbstractListVector list, String[] cachedNames) {
+            RStringVector listNames = getNamesDots.getNames(list);
+            if (listNames == null && cachedNames == null) {
                 return true;
-            } else if (listNames == null || cachedListNames == null) {
+            } else if (listNames == null || cachedNames == null) {
                 return false;
             } else {
-
-                for (int i = 0; i < list.getLength(); i++) {
+                for (int i = 0; i < cachedNames.length; i++) {
                     String name = listNames.getDataAt(i);
-                    String cachedName = cachedListNames.getDataAt(i);
-
+                    String cachedName = cachedNames[i];
                     if (name == cachedName) {
                         continue;
-                    } else if (name == null || cachedName == null) {
-                        return false;
-                    } else if (!(name.equals(cachedName))) {
-                        return false;
+                    } else {
+                        nonPerfectMatch.enter();
+                        if (name == null || cachedName == null) {
+                            return false;
+                        } else if (!equals(name, cachedName)) {
+                            return false;
+                        }
                     }
                 }
                 return true;
             }
         }
+
+        @TruffleBoundary
+        private static boolean equals(String name, String cachedName) {
+            return name.equals(cachedName);
+        }
     }
 }
-- 
GitLab