From 59e08eed83cbdcf8030c49ebe98c2cf76000c4f3 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Wed, 28 Jun 2017 14:46:58 +0200 Subject: [PATCH] refactor mapply to proper loops --- .../truffle/r/nodes/builtin/base/Mapply.java | 128 +++++++++++------- 1 file changed, 78 insertions(+), 50 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java index e649410a89..7c53ab3e42 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mapply.java @@ -28,10 +28,13 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.CompilerAsserts; +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.r.nodes.access.WriteVariableNode; import com.oracle.truffle.r.nodes.access.WriteVariableNode.Mode; import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode; @@ -117,7 +120,10 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { } } - @Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create(); + @Child private GetNamesAttributeNode getNamesDots = GetNamesAttributeNode.create(); + @Child private GetNamesAttributeNode getNamesMoreArgs = GetNamesAttributeNode.create(); + + private final BranchProfile nonPerfectMatch = BranchProfile.create(); public abstract Object[] execute(VirtualFrame frame, RAbstractListVector dots, RFunction function, RAbstractListVector additionalArguments); @@ -125,18 +131,46 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { return extractNode.apply(frame, dots.getDataAt(listIndex), new Object[]{i % lengths[listIndex] + 1}, RLogical.TRUE, RLogical.TRUE); } - @SuppressWarnings("unused") - @Specialization(limit = "5", guards = {"dots.getLength() == cachedDots.getLength()", - "moreArgs.getLength() == cachedMoreArgs.getLength()", "sameNames(dots, cachedDots)", - "sameNames(moreArgs, cachedMoreArgs)"}) + @Specialization(limit = "5", guards = {"dots.getLength() == dotsLength", "moreArgs.getLength() == moreArgsLength", + "sameNames(dots, cachedDotsNames)", "sameNames(moreArgs, cachedMoreArgsNames)"}) protected Object[] cachedMApply(VirtualFrame frame, RAbstractListVector dots, RFunction function, RAbstractListVector moreArgs, - @Cached("dots") RAbstractListVector cachedDots, - @Cached("moreArgs") RAbstractListVector cachedMoreArgs, - @Cached("createElementNodeArray(cachedDots, cachedMoreArgs)") ElementNode[] cachedElementNodeArray, + @Cached("dots.getLength()") int dotsLength, + @Cached("moreArgs.getLength()") int moreArgsLength, + @SuppressWarnings("unused") @Cached("extractNames(dots)") String[] cachedDotsNames, + @SuppressWarnings("unused") @Cached("extractNames(moreArgs)") String[] cachedMoreArgsNames, + @Cached("createElementNodeArray(dotsLength, moreArgsLength, cachedDotsNames, cachedMoreArgsNames)") ElementNode[] cachedElementNodeArray, @Cached("createCallNode(cachedElementNodeArray)") RCallNode callNode) { - int dotsLength = dots.getLength(); - int moreArgsLength = moreArgs.getLength(); int[] lengths = new int[dotsLength]; + int maxLength = getDotsLengths(frame, dots, dotsLength, cachedElementNodeArray, lengths); + storeAdditionalArguments(frame, moreArgs, dotsLength, moreArgsLength, cachedElementNodeArray); + Object[] result = new Object[maxLength]; + for (int i = 0; i < maxLength; i++) { + /* Evaluate and store the arguments */ + prepareElements(frame, dots, dotsLength, cachedElementNodeArray, lengths, i); + /* Now call the function */ + result[i] = callNode.execute(frame, function); + } + return result; + } + + @ExplodeLoop + private static void prepareElements(VirtualFrame frame, RAbstractListVector dots, int dotsLength, ElementNode[] cachedElementNodeArray, int[] lengths, int i) { + for (int listIndex = 0; listIndex < dotsLength; listIndex++) { + Object vecElement = getVecElement(frame, dots, i, listIndex, lengths, cachedElementNodeArray[listIndex].extractNode); + cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, vecElement); + } + } + + @ExplodeLoop + private static void storeAdditionalArguments(VirtualFrame frame, RAbstractListVector moreArgs, int dotsLength, int moreArgsLength, ElementNode[] cachedElementNodeArray) { + for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) { + // store additional arguments + cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, moreArgs.getDataAt(listIndex - dotsLength)); + } + } + + @ExplodeLoop + private static int getDotsLengths(VirtualFrame frame, RAbstractListVector dots, int dotsLength, ElementNode[] cachedElementNodeArray, int[] lengths) { int maxLength = -1; for (int i = 0; i < dotsLength; i++) { int length = cachedElementNodeArray[i].lengthNode.executeInt(frame, dots.getDataAt(i)); @@ -145,22 +179,13 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { } lengths[i] = length; } - for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) { - // store additional arguments - cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, - moreArgs.getDataAt(listIndex - dotsLength)); - } - Object[] result = new Object[maxLength]; - for (int i = 0; i < maxLength; i++) { - /* Evaluate and store the arguments */ - for (int listIndex = 0; listIndex < dotsLength; listIndex++) { - Object vecElement = getVecElement(frame, dots, i, listIndex, lengths, cachedElementNodeArray[listIndex].extractNode); - cachedElementNodeArray[listIndex].writeVectorElementNode.execute(frame, vecElement); - } - /* Now call the function */ - result[i] = callNode.execute(frame, function); - } - return result; + return maxLength; + } + + protected static String[] extractNames(RAbstractListVector list) { + CompilerAsserts.neverPartOfCompilation(); + RStringVector names = list.getNames(); + return names == null ? null : names.getDataCopy(); } @Specialization(replaces = "cachedMApply") @@ -181,13 +206,13 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { } Object[] values = new Object[dotsLength + moreArgsLength]; String[] names = new String[dotsLength + moreArgsLength]; - RStringVector dotsNames = getNames.getNames(dots); + RStringVector dotsNames = getNamesDots.getNames(dots); if (dotsNames != null) { for (int listIndex = 0; listIndex < dotsLength; listIndex++) { names[listIndex] = dotsNames.getDataAt(listIndex).isEmpty() ? null : dotsNames.getDataAt(listIndex); } } - RStringVector moreArgsNames = getNames.getNames(moreArgs); + RStringVector moreArgsNames = getNamesMoreArgs.getNames(moreArgs); for (int listIndex = dotsLength; listIndex < dotsLength + moreArgsLength; listIndex++) { values[listIndex] = moreArgs.getDataAt(listIndex - dotsLength); names[listIndex] = moreArgsNames == null ? null : (moreArgsNames.getDataAt(listIndex - dotsLength).isEmpty() ? null : moreArgsNames.getDataAt(listIndex - dotsLength)); @@ -221,17 +246,15 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { return RCallNode.createCall(Lapply.createCallSourceSection(), null, ArgumentsSignature.get(names), syntaxNodes); } - protected ElementNode[] createElementNodeArray(RAbstractListVector dots, RAbstractListVector moreArgs) { - int length = dots.getLength() + moreArgs.getLength(); + protected ElementNode[] createElementNodeArray(int dotsLength, int moreArgsLength, String[] cachedDotsNames, String[] cachedMoreArgsNames) { + int length = dotsLength + moreArgsLength; ElementNode[] elementNodes = new ElementNode[length]; - RStringVector dotsNames = getNames.getNames(dots); - for (int i = 0; i < dots.getLength(); i++) { - elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1), dotsNames == null ? null : (dotsNames.getDataAt(i).isEmpty() ? null : dotsNames.getDataAt(i)))); + for (int i = 0; i < dotsLength; i++) { + elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1), cachedDotsNames == null ? null : (cachedDotsNames[i].isEmpty() ? null : cachedDotsNames[i]))); } - RStringVector moreArgsNames = getNames.getNames(moreArgs); - for (int i = dots.getLength(); i < dots.getLength() + moreArgs.getLength(); i++) { - elementNodes[i] = insert(new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1), - moreArgsNames == null ? null : moreArgsNames.getDataAt(i - dots.getLength()).isEmpty() ? null : moreArgsNames.getDataAt(i - dots.getLength()))); + for (int i = 0; i < moreArgsLength; i++) { + elementNodes[i + dotsLength] = insert( + new ElementNode(VECTOR_ELEMENT_PREFIX + (i + 1 + dotsLength), cachedMoreArgsNames == null ? null : cachedMoreArgsNames[i].isEmpty() ? null : cachedMoreArgsNames[i])); } return elementNodes; } @@ -240,29 +263,34 @@ public abstract class Mapply extends RBuiltinNode.Arg3 { return ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false); } - protected boolean sameNames(RAbstractListVector list, RAbstractListVector cachedList) { - RStringVector listNames = getNames.getNames(list); - RStringVector cachedListNames = getNames.getNames(cachedList); - if (listNames == null && cachedListNames == null) { + protected boolean sameNames(RAbstractListVector list, String[] cachedNames) { + RStringVector listNames = getNamesDots.getNames(list); + if (listNames == null && cachedNames == null) { return true; - } else if (listNames == null || cachedListNames == null) { + } else if (listNames == null || cachedNames == null) { return false; } else { - - for (int i = 0; i < list.getLength(); i++) { + for (int i = 0; i < cachedNames.length; i++) { String name = listNames.getDataAt(i); - String cachedName = cachedListNames.getDataAt(i); - + String cachedName = cachedNames[i]; if (name == cachedName) { continue; - } else if (name == null || cachedName == null) { - return false; - } else if (!(name.equals(cachedName))) { - return false; + } else { + nonPerfectMatch.enter(); + if (name == null || cachedName == null) { + return false; + } else if (!equals(name, cachedName)) { + return false; + } } } return true; } } + + @TruffleBoundary + private static boolean equals(String name, String cachedName) { + return name.equals(cachedName); + } } } -- GitLab