Skip to content
Snippets Groups Projects
Commit 2b05e288 authored by Florian Angerer's avatar Florian Angerer
Browse files

Fix: Fallback into slow path when materializing deeply recursive data

structures.
parent 7baa9464
Branches
No related tags found
No related merge requests found
...@@ -236,10 +236,10 @@ public class CallAndExternalFunctions { ...@@ -236,10 +236,10 @@ public class CallAndExternalFunctions {
return new Object[]{RMissing.instance, RArgsValuesAndNames.EMPTY, RMissing.instance}; return new Object[]{RMissing.instance, RArgsValuesAndNames.EMPTY, RMissing.instance};
} }
private Object[] materializeArgs(VirtualFrame frame, Object[] args) { private Object[] materializeArgs(Object[] args) {
Object[] materializedArgs = new Object[args.length]; Object[] materializedArgs = new Object[args.length];
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
materializedArgs[i] = materializeNode.execute(frame, args[i]); materializedArgs[i] = materializeNode.execute(args[i]);
} }
return materializedArgs; return materializedArgs;
} }
...@@ -666,38 +666,37 @@ public class CallAndExternalFunctions { ...@@ -666,38 +666,37 @@ public class CallAndExternalFunctions {
*/ */
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Specialization(limit = "2", guards = {"cached == symbol", "builtin == null"}) @Specialization(limit = "2", guards = {"cached == symbol", "builtin == null"})
protected Object callNamedFunction(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, Object packageName, protected Object callNamedFunction(RList symbol, RArgsValuesAndNames args, Object packageName,
@Cached("symbol") RList cached, @Cached("symbol") RList cached,
@Cached("lookupBuiltin(symbol)") RExternalBuiltinNode builtin, @Cached("lookupBuiltin(symbol)") RExternalBuiltinNode builtin,
@Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo, @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo,
@Cached("extractSymbolInfo.execute(symbol)") NativeCallInfo nativeCallInfo) { @Cached("extractSymbolInfo.execute(symbol)") NativeCallInfo nativeCallInfo) {
return callRFFINode.dispatch(nativeCallInfo, materializeArgs(frame, args.getArguments())); return callRFFINode.dispatch(nativeCallInfo, materializeArgs(args.getArguments()));
} }
/** /**
* For some reason, the list instance may change, although it carries the same info. For * For some reason, the list instance may change, although it carries the same info. For
* such cases there is this generic version. * such cases there is this generic version.
*/ */
@SuppressWarnings("unused")
@Specialization(replaces = {"callNamedFunction", "doExternal"}) @Specialization(replaces = {"callNamedFunction", "doExternal"})
protected Object callNamedFunctionGeneric(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, Object packageName, protected Object callNamedFunctionGeneric(RList symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") Object packageName,
@Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo) { @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo) {
RExternalBuiltinNode builtin = lookupBuiltin(symbol); RExternalBuiltinNode builtin = lookupBuiltin(symbol);
if (builtin != null) { if (builtin != null) {
throw RInternalError.shouldNotReachHere("Cache for .Calls with FastR reimplementation (lookupBuiltin(...) != null) exceeded the limit"); throw RInternalError.shouldNotReachHere("Cache for .Calls with FastR reimplementation (lookupBuiltin(...) != null) exceeded the limit");
} }
NativeCallInfo nativeCallInfo = extractSymbolInfo.execute(symbol); NativeCallInfo nativeCallInfo = extractSymbolInfo.execute(symbol);
return callRFFINode.dispatch(nativeCallInfo, materializeArgs(frame, args.getArguments())); return callRFFINode.dispatch(nativeCallInfo, materializeArgs(args.getArguments()));
} }
/** /**
* {@code .NAME = string}, no package specified. * {@code .NAME = string}, no package specified.
*/ */
@Specialization @Specialization
protected Object callNamedFunction(VirtualFrame frame, String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, protected Object callNamedFunction(String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName,
@Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns, @Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns,
@Cached("create()") DLL.RFindSymbolNode findSymbolNode) { @Cached("create()") DLL.RFindSymbolNode findSymbolNode) {
return callNamedFunctionWithPackage(frame, symbol, args, null, rns, findSymbolNode); return callNamedFunctionWithPackage(symbol, args, null, rns, findSymbolNode);
} }
/** /**
...@@ -705,19 +704,19 @@ public class CallAndExternalFunctions { ...@@ -705,19 +704,19 @@ public class CallAndExternalFunctions {
* define that symbol. * define that symbol.
*/ */
@Specialization @Specialization
protected Object callNamedFunctionWithPackage(VirtualFrame frame, String symbol, RArgsValuesAndNames args, String packageName, protected Object callNamedFunctionWithPackage(String symbol, RArgsValuesAndNames args, String packageName,
@Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns, @Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns,
@Cached("create()") DLL.RFindSymbolNode findSymbolNode) { @Cached("create()") DLL.RFindSymbolNode findSymbolNode) {
DLL.SymbolHandle func = findSymbolNode.execute(symbol, packageName, rns); DLL.SymbolHandle func = findSymbolNode.execute(symbol, packageName, rns);
if (func == DLL.SYMBOL_NOT_FOUND) { if (func == DLL.SYMBOL_NOT_FOUND) {
throw error(RError.Message.SYMBOL_NOT_IN_TABLE, symbol, "Call", packageName); throw error(RError.Message.SYMBOL_NOT_IN_TABLE, symbol, "Call", packageName);
} }
return callRFFINode.dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), materializeArgs(frame, args.getArguments())); return callRFFINode.dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), materializeArgs(args.getArguments()));
} }
@Specialization @Specialization
protected Object callNamedFunctionWithPackage(VirtualFrame frame, RExternalPtr symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName) { protected Object callNamedFunctionWithPackage(RExternalPtr symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName) {
return callRFFINode.dispatch(new NativeCallInfo("", symbol.getAddr(), null), materializeArgs(frame, args.getArguments())); return callRFFINode.dispatch(new NativeCallInfo("", symbol.getAddr(), null), materializeArgs(args.getArguments()));
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
......
...@@ -24,10 +24,10 @@ package com.oracle.truffle.r.nodes.helpers; ...@@ -24,10 +24,10 @@ package com.oracle.truffle.r.nodes.helpers;
import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.r.nodes.attributes.HasAttributesNode; import com.oracle.truffle.r.nodes.attributes.HasAttributesNode;
import com.oracle.truffle.r.nodes.attributes.IterableAttributeNode; import com.oracle.truffle.r.nodes.attributes.IterableAttributeNode;
...@@ -45,9 +45,6 @@ public abstract class MaterializeNode extends Node { ...@@ -45,9 +45,6 @@ public abstract class MaterializeNode extends Node {
@Child private IterableAttributeNode attributesIt; @Child private IterableAttributeNode attributesIt;
@Child private SetAttributeNode setAttributeNode; @Child private SetAttributeNode setAttributeNode;
@Child private MaterializeNode recursive;
@Child private MaterializeNode recursiveAttr;
private final boolean deep; private final boolean deep;
protected MaterializeNode(boolean deep) { protected MaterializeNode(boolean deep) {
...@@ -58,63 +55,39 @@ public abstract class MaterializeNode extends Node { ...@@ -58,63 +55,39 @@ public abstract class MaterializeNode extends Node {
} }
} }
public abstract Object execute(VirtualFrame frame, Object arg); public abstract Object execute(Object arg);
@Specialization @Specialization
protected RList doList(VirtualFrame frame, RList vec) { protected RList doList(RList vec) {
RList materialized = materializeContents(frame, vec); RList materialized = materializeContents(vec);
materializeAttributes(frame, materialized); materializeAttributes(materialized);
return materialized; return materialized;
} }
private RList materializeContents(VirtualFrame frame, RList list) {
boolean changed = false;
RList materializedContents = null;
for (int i = 0; i < list.getLength(); i++) {
if (recursive == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
recursive = insert(MaterializeNode.create(deep));
}
Object element = list.getDataAt(i);
Object materializedElem = recursive.execute(frame, element);
if (materializedElem != element) {
materializedContents = (RList) list.copy();
changed = true;
}
if (changed && materializedElem != element) {
materializedContents.setDataAt(i, materializedElem);
}
}
if (changed) {
return materializedContents;
}
return list;
}
@Specialization(limit = "LIMIT", guards = {"vec.getClass() == cachedClass"}) @Specialization(limit = "LIMIT", guards = {"vec.getClass() == cachedClass"})
protected RAttributable doAbstractContainerCached(VirtualFrame frame, RAttributable vec, protected RAttributable doAbstractContainerCached(RAttributable vec,
@SuppressWarnings("unused") @Cached("vec.getClass()") Class<?> cachedClass) { @SuppressWarnings("unused") @Cached("vec.getClass()") Class<?> cachedClass) {
if (vec instanceof RList) { if (vec instanceof RList) {
return doList(frame, (RList) vec); return doList((RList) vec);
} else if (vec instanceof RAbstractContainer) { } else if (vec instanceof RAbstractContainer) {
RAbstractContainer materialized = ((RAbstractContainer) vec).materialize(); RAbstractContainer materialized = ((RAbstractContainer) vec).materialize();
materializeAttributes(frame, materialized); materializeAttributes(materialized);
return materialized; return materialized;
} }
materializeAttributes(frame, vec); materializeAttributes(vec);
return vec; return vec;
} }
@Specialization(replaces = "doAbstractContainerCached") @Specialization(replaces = "doAbstractContainerCached")
protected RAttributable doAbstractContainer(VirtualFrame frame, RAttributable vec) { protected RAttributable doAbstractContainer(RAttributable vec) {
if (vec instanceof RList) { if (vec instanceof RList) {
return doList(frame, (RList) vec); return doList((RList) vec);
} else if (vec instanceof RAbstractContainer) { } else if (vec instanceof RAbstractContainer) {
RAbstractContainer materialized = ((RAbstractContainer) vec).materialize(); RAbstractContainer materialized = ((RAbstractContainer) vec).materialize();
materializeAttributes(frame, materialized); materializeAttributes(materialized);
return materialized; return materialized;
} }
materializeAttributes(frame, vec); materializeAttributes(vec);
return vec; return vec;
} }
...@@ -123,17 +96,35 @@ public abstract class MaterializeNode extends Node { ...@@ -123,17 +96,35 @@ public abstract class MaterializeNode extends Node {
return o; return o;
} }
private void materializeAttributes(VirtualFrame frame, RAttributable materialized) { private RList materializeContents(RList list) {
boolean changed = false;
RList materializedContents = null;
for (int i = 0; i < list.getLength(); i++) {
Object element = list.getDataAt(i);
Object materializedElem = doGenericSlowPath(element);
if (materializedElem != element) {
materializedContents = (RList) list.copy();
changed = true;
}
if (changed && materializedElem != element) {
materializedContents.setDataAt(i, materializedElem);
}
}
if (changed) {
return materializedContents;
}
return list;
}
private void materializeAttributes(RAttributable materialized) {
// TODO we could further optimize by first checking for fixed/special attributes // TODO we could further optimize by first checking for fixed/special attributes
if (deep && hasAttributes.execute(materialized)) { if (deep && hasAttributes.execute(materialized)) {
if (attributesIt == null) { if (attributesIt == null) {
assert recursiveAttr == null;
CompilerDirectives.transferToInterpreterAndInvalidate(); CompilerDirectives.transferToInterpreterAndInvalidate();
attributesIt = insert(IterableAttributeNode.create()); attributesIt = insert(IterableAttributeNode.create());
recursiveAttr = insert(MaterializeNode.create(deep));
} }
for (RAttribute attr : attributesIt.execute(materialized)) { for (RAttribute attr : attributesIt.execute(materialized)) {
Object materializedAttr = recursiveAttr.execute(frame, attr.getValue()); Object materializedAttr = doGenericSlowPath(attr.getValue());
if (materializedAttr != attr.getValue()) { if (materializedAttr != attr.getValue()) {
if (setAttributeNode == null) { if (setAttributeNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate(); CompilerDirectives.transferToInterpreterAndInvalidate();
...@@ -145,6 +136,14 @@ public abstract class MaterializeNode extends Node { ...@@ -145,6 +136,14 @@ public abstract class MaterializeNode extends Node {
} }
} }
@TruffleBoundary
private Object doGenericSlowPath(Object element) {
if (element instanceof RAttributable) {
return doAbstractContainer((RAttributable) element);
}
return element;
}
public static MaterializeNode create(boolean deep) { public static MaterializeNode create(boolean deep) {
return MaterializeNodeGen.create(deep); return MaterializeNodeGen.create(deep);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment