From 6f4f9eda5e99b9c4c4d4f4fcd9f66702e631fe34 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Thu, 10 Nov 2016 15:55:39 +0100 Subject: [PATCH] replacement special calls return the rhs value with the exception if it was already evaluated --- .../nodes/builtin/base/infix/UpdateField.java | 4 +- .../builtin/base/infix/UpdateSubscript.java | 2 +- .../control/ReplacementDispatchNode.java | 18 +++- .../r/nodes/control/ReplacementNode.java | 95 ++++++++++++++++++- .../r/nodes/function/RCallSpecialNode.java | 47 +++++---- .../r/runtime/builtins/RSpecialFactory.java | 26 ++++- 6 files changed, 157 insertions(+), 35 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java index 35b1f19c4a..da44b6344d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java @@ -69,7 +69,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { public RList doList(RList list, String field, Object value, @Cached("getIndex(list.getNames(), field)") int index) { if (index == -1) { - throw RSpecialFactory.throwFullCallNeeded(); + throw RSpecialFactory.throwFullCallNeeded(value); } updateCache(list, field); Object sharedValue = value; @@ -89,7 +89,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { @SuppressWarnings("unused") @Fallback public void doFallback(Object container, Object field, Object value) { - throw RSpecialFactory.throwFullCallNeeded(); + throw RSpecialFactory.throwFullCallNeeded(value); } private ShareObjectNode getShareObjectNode() { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java index 4d17163f6d..5a68b3998b 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java @@ -128,7 +128,7 @@ abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon { @SuppressWarnings("unused") @Fallback protected static Object setFallback(Object vector, Object index, Object value) { - throw RSpecialFactory.throwFullCallNeeded(); + throw RSpecialFactory.throwFullCallNeeded(value); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java index caecfb41d8..656bb60251 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java @@ -73,15 +73,23 @@ public final class ReplacementDispatchNode extends OperatorNode { @Override public Object execute(VirtualFrame frame) { CompilerDirectives.transferToInterpreterAndInvalidate(); + return create(false).execute(frame); + } + + @Override + public void voidExecute(VirtualFrame frame) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + create(true).voidExecute(frame); + } + public RNode create(boolean isVoid) { RNode replacement; if (lhs instanceof RSyntaxCall) { - replacement = createReplacementNode(); + replacement = createReplacementNode(isVoid); } else { replacement = new WriteVariableSyntaxNode(getLazySourceSection(), operator, lhs.asRSyntaxNode(), rhs, isSuper); } - - return replace(replacement).execute(frame); + return replace(replacement); } @Override @@ -94,7 +102,7 @@ public final class ReplacementDispatchNode extends OperatorNode { return new RSyntaxElement[]{lhs.asRSyntaxNode(), rhs.asRSyntaxNode()}; } - private ReplacementNode createReplacementNode() { + private ReplacementNode createReplacementNode(boolean isVoid) { CompilerAsserts.neverPartOfCompilation(); /* @@ -122,7 +130,7 @@ public final class ReplacementDispatchNode extends OperatorNode { } RSyntaxLookup variable = (RSyntaxLookup) current; ReadVariableNode varRead = createReplacementForVariableUsing(variable, isSuper); - return ReplacementNode.create(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, variable.getIdentifier(), isSuper, tempNamesStartIndex); + return ReplacementNode.create(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, variable.getIdentifier(), isSuper, tempNamesStartIndex, isVoid); } private static ReadVariableNode createReplacementForVariableUsing(RSyntaxLookup var, boolean isSuper) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java index e97d211ff5..1e0d45dfc2 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java @@ -65,12 +65,16 @@ abstract class ReplacementNode extends OperatorNode { } public static ReplacementNode create(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls, - String targetVarName, boolean isSuper, int tempNamesStartIndex) { + String targetVarName, boolean isSuper, int tempNamesStartIndex, boolean isVoid) { CompilerAsserts.neverPartOfCompilation(); // Note: if specials are turned off in FastR, onlySpecials will never be true boolean createSpecial = hasOnlySpecialCalls(calls); if (createSpecial) { - return new SpecialReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); + if (isVoid) { + return new SpecialVoidReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); + } else { + return new SpecialReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); + } } else { return new GenericReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); } @@ -102,7 +106,7 @@ abstract class ReplacementNode extends OperatorNode { argNodes[i] = i == 0 ? newFirstArg : builder.process(arguments[i], codeBuilderContext); } - return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes).asRNode(); + return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes, 0).asRNode(); } /** @@ -144,7 +148,7 @@ abstract class ReplacementNode extends OperatorNode { newArgs[1] = builder.lookup(oldArgs[1].getLazySourceSection(), ((RSyntaxLookup) oldArgs[1]).getIdentifier() + "<-", true); newSyntaxLHS = RCallSpecialNode.createCall(callLHS.getLazySourceSection(), ((RSyntaxNode) callLHS.getSyntaxLHS()).asRNode(), callLHS.getSyntaxSignature(), newArgs); } - return RCallSpecialNode.createCallInReplace(source, newSyntaxLHS.asRNode(), ArgumentsSignature.get(names), argNodes).asRNode(); + return RCallSpecialNode.createCallInReplace(source, newSyntaxLHS.asRNode(), ArgumentsSignature.get(names), argNodes, 0, argNodes.length - 1).asRNode(); } static RLanguage getLanguage(WriteVariableNode wvn) { @@ -193,6 +197,12 @@ abstract class ReplacementNode extends OperatorNode { removeRhs.execute(frame); } + protected final void voidExecuteWithRhs(VirtualFrame frame, Object rhsValue) { + storeRhs.execute(frame, rhsValue); + executeReplacement(frame); + removeRhs.execute(frame); + } + protected abstract void executeReplacement(VirtualFrame frame); @Override @@ -231,7 +241,7 @@ abstract class ReplacementNode extends OperatorNode { RNode extractFunc = target; for (int i = calls.size() - 1; i >= 1; i--) { extractFunc = createSpecialFunctionQuery(calls.get(i), extractFunc.asRSyntaxNode(), codeBuilderContext); - assert extractFunc instanceof RCallSpecialNode; + ((RCallSpecialNode) extractFunc).setPropagateFullCallNeededException(true); } this.replaceCall = (RCallSpecialNode) createFunctionUpdate(source, extractFunc.asRSyntaxNode(), ReadVariableNode.create("*rhs*" + tempNamesStartIndex), calls.get(0), codeBuilderContext); this.replaceCall.setPropagateFullCallNeededException(true); @@ -252,6 +262,81 @@ abstract class ReplacementNode extends OperatorNode { } } + /** + * Replacement that is made of only special calls, if one of the special calls falls back to + * full version, the replacement also falls back to {@link ReplacementNode}. Additionally, this + * type only works if the result of the call (the rhs) is not needed. + */ + private static final class SpecialVoidReplacementNode extends ReplacementNode { + + @Child private RCallSpecialNode replaceCall; + + private final RNode rhs; + private final List<RSyntaxCall> calls; + private final int tempNamesStartIndex; + private final boolean isSuper; + private final String targetVarName; + private final RNode target; + + SpecialVoidReplacementNode(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls, String targetVarName, + boolean isSuper, int tempNamesStartIndex) { + super(source, operator, lhs); + this.target = target; + this.rhs = rhs; + this.calls = calls; + this.targetVarName = targetVarName; + this.isSuper = isSuper; + this.tempNamesStartIndex = tempNamesStartIndex; + + /* + * Creates a replacement that consists only of {@link RCallSpecialNode} calls. + */ + CodeBuilderContext codeBuilderContext = new CodeBuilderContext(tempNamesStartIndex + 2); + RNode extractFunc = target; + for (int i = calls.size() - 1; i >= 1; i--) { + extractFunc = createSpecialFunctionQuery(calls.get(i), extractFunc.asRSyntaxNode(), codeBuilderContext); + ((RCallSpecialNode) extractFunc).setPropagateFullCallNeededException(true); + } + this.replaceCall = (RCallSpecialNode) createFunctionUpdate(source, extractFunc.asRSyntaxNode(), rhs.asRSyntaxNode(), calls.get(0), codeBuilderContext); + this.replaceCall.setPropagateFullCallNeededException(true); + } + + @Override + public Object execute(VirtualFrame frame) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + GenericReplacementNode replacement = new GenericReplacementNode(getLazySourceSection(), operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); + return replace(replacement).execute(frame); + } + + @Override + public void voidExecute(VirtualFrame frame) { + try { + // Note: the very last call is the actual assignment, e.g. [[<-, if this call's + // argument is shared, it bails out. Moreover, if that call's argument is not + // shared, it could not be extracted from a shared container (list), so we should be + // OK with not calling any other update function and just update the value directly. + replaceCall.execute(frame); + } catch (FullCallNeededException | RecursiveSpecialBailout e) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + GenericReplacementNode replacement = replace(new GenericReplacementNode(getLazySourceSection(), operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex)); + + Object rhsValue = e instanceof FullCallNeededException ? ((FullCallNeededException) e).rhsValue : ((RecursiveSpecialBailout) e).rhsValue; + if (rhsValue == null) { + // we haven't queried the rhs value yet + replacement.voidExecute(frame); + } else { + // rhs was already queried, so pass it along + replacement.voidExecuteWithRhs(frame, rhsValue); + } + } + } + + @Override + public RSyntaxElement[] getSyntaxArguments() { + return new RSyntaxElement[]{lhs, rhs.asRSyntaxNode()}; + } + } + /** * Holds the sequence of nodes created for R's replacement assignment. */ diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java index b9b3f54935..8259598aa4 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java @@ -145,22 +145,22 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode } /** - * This passes {@code true} for the isReplacement parameter and ignores the first argument, - * i.e., does not modify the first argument in any way before passing it to + * This passes {@code true} for the isReplacement parameter and ignores the specified arguments, + * i.e., does not modify them in any way before passing it to * {@link RSpecialFactory#create(ArgumentsSignature, RNode[], boolean)}. */ - public static RSyntaxNode createCallInReplace(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) { - return createCall(sourceSection, functionNode, signature, arguments, true); + public static RSyntaxNode createCallInReplace(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, int... ignoredArguments) { + return createCall(sourceSection, functionNode, signature, arguments, true, ignoredArguments); } public static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) { return createCall(sourceSection, functionNode, signature, arguments, false); } - private static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace) { + private static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace, int... ignoredArguments) { RCallSpecialNode special = null; if (useSpecials) { - special = tryCreate(sourceSection, functionNode, signature, arguments, inReplace); + special = tryCreate(sourceSection, functionNode, signature, arguments, inReplace, ignoredArguments); } if (special != null) { return special; @@ -169,14 +169,17 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode } } - private static RCallSpecialNode tryCreate(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace) { + private static RCallSpecialNode tryCreate(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace, int[] ignoredArguments) { RSyntaxNode syntaxFunction = functionNode.asRSyntaxNode(); if (!(syntaxFunction instanceof RSyntaxLookup)) { // LHS is not a simple lookup -> bail out return null; } - for (RSyntaxNode argument : arguments) { - if (!(argument instanceof RSyntaxLookup || argument instanceof RSyntaxConstant || argument instanceof RCallSpecialNode)) { + for (int i = 0; i < arguments.length; i++) { + if (contains(ignoredArguments, i)) { + continue; + } + if (!(arguments[i] instanceof RSyntaxLookup || arguments[i] instanceof RSyntaxConstant || arguments[i] instanceof RCallSpecialNode)) { // argument is not a simple lookup or constant value or another special -> bail out return null; } @@ -194,10 +197,7 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode } RNode[] localArguments = new RNode[arguments.length]; for (int i = 0; i < arguments.length; i++) { - if (inReplace && i == 0) { - if (arguments[i] instanceof RCallSpecialNode) { - ((RCallSpecialNode) arguments[i]).setArgumentIndex(i); - } + if (inReplace && contains(ignoredArguments, i)) { localArguments[i] = arguments[i].asRNode(); } else { if (arguments[i] instanceof RSyntaxLookup) { @@ -222,6 +222,15 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode return new RCallSpecialNode(sourceSection, functionNode, expectedFunction, arguments, signature, special); } + private static boolean contains(int[] ignoredArguments, int index) { + for (int i = 0; i < ignoredArguments.length; i++) { + if (ignoredArguments[i] == index) { + return true; + } + } + return false; + } + @Override public Object execute(VirtualFrame frame, Object function) { try { @@ -232,18 +241,18 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode return special.execute(frame); } catch (RecursiveSpecialBailout bailout) { CompilerDirectives.transferToInterpreterAndInvalidate(); - throwOnRecursiveSpecial(); + throwOnRecursiveSpecial(bailout.rhsValue); return replace(getRCallNode(rewriteSpecialArgument(bailout))).execute(frame, function); } catch (RSpecialFactory.FullCallNeededException e) { CompilerDirectives.transferToInterpreterAndInvalidate(); - throwOnRecursiveSpecial(); + throwOnRecursiveSpecial(e.rhsValue); return replace(getRCallNode()).execute(frame, function); } } - private void throwOnRecursiveSpecial() { + private void throwOnRecursiveSpecial(Object rhsValue) { if (isRecursiveSpecial()) { - throw new RecursiveSpecialBailout(argumentIndex); + throw new RecursiveSpecialBailout(argumentIndex, rhsValue); } } @@ -313,9 +322,11 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode @SuppressWarnings("serial") public static final class RecursiveSpecialBailout extends RuntimeException { public final int argumentIndex; + public final Object rhsValue; - RecursiveSpecialBailout(int argumentIndex) { + RecursiveSpecialBailout(int argumentIndex, Object rhsValue) { this.argumentIndex = argumentIndex; + this.rhsValue = rhsValue; } @Override diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java index d078aab5d1..f5646bc4ac 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java @@ -33,9 +33,26 @@ import com.oracle.truffle.r.runtime.nodes.RNode; * {@link #throwFullCallNeeded()} and the it will be replaced with call to the full blown built-in. */ public interface RSpecialFactory { + + /** + * Signals that the current special call cannot fulfill the requested action because of some + * restriction, and that the call should be rewritten to the generic call. This function must + * not be used if the rhs value of an update call was already computed - in this case, use the + * {@link #throwFullCallNeeded(Object)} function instead. + */ static FullCallNeededException throwFullCallNeeded() { CompilerDirectives.transferToInterpreterAndInvalidate(); - throw FullCallNeededException.INSTANCE; + throw new FullCallNeededException(null); + } + + /** + * Signals that the current special call cannot fulfill the requested action because of some + * restriction, and that the call should be rewritten to the generic call. This function must be + * used when a rhs value of an update call was already computed and must not be recomputed. + */ + static FullCallNeededException throwFullCallNeeded(Object rhsValue) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + throw new FullCallNeededException(rhsValue); } /** @@ -51,10 +68,11 @@ public interface RSpecialFactory { @SuppressWarnings("serial") final class FullCallNeededException extends RuntimeException { - private static RuntimeException INSTANCE = new FullCallNeededException(); - private FullCallNeededException() { - // singleton + public Object rhsValue; + + private FullCallNeededException(Object rhsValue) { + this.rhsValue = rhsValue; } @Override -- GitLab