Skip to content
Snippets Groups Projects
Commit b445161e authored by Lukas Stadler's avatar Lukas Stadler
Browse files

explicitly configure visibility for fast paths, fast paths subclass RBaseNode instead of RNode

parent 96dd3d1c
Branches
No related tags found
No related merge requests found
......@@ -22,6 +22,8 @@
*/
package com.oracle.truffle.r.nodes.builtin.base;
import java.util.function.Supplier;
import com.oracle.truffle.api.frame.MaterializedFrame;
import com.oracle.truffle.r.nodes.RRootNode;
import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
......@@ -78,6 +80,7 @@ import com.oracle.truffle.r.nodes.builtin.fastr.FastrDqrls;
import com.oracle.truffle.r.nodes.builtin.fastr.FastrDqrlsNodeGen;
import com.oracle.truffle.r.nodes.unary.UnaryNotNode;
import com.oracle.truffle.r.nodes.unary.UnaryNotNodeGen;
import com.oracle.truffle.r.runtime.RVisibility;
import com.oracle.truffle.r.runtime.builtins.FastPathFactory;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RFunction;
......@@ -640,7 +643,11 @@ public class BasePackage extends RBuiltinPackage {
((RRootNode) function.getRootNode()).setFastPath(factory);
}
private static void addFastPath(MaterializedFrame baseFrame, String name, java.util.function.Supplier<RFastPathNode> factory, Class<?> builtinNodeClass) {
private static void addFastPath(MaterializedFrame baseFrame, String name, Supplier<RFastPathNode> factory, RVisibility visibility) {
addFastPath(baseFrame, name, FastPathFactory.fromVisibility(visibility, factory));
}
private static void addFastPath(MaterializedFrame baseFrame, String name, Supplier<RFastPathNode> factory, Class<?> builtinNodeClass) {
RBuiltin builtin = builtinNodeClass.getAnnotation(RBuiltin.class);
addFastPath(baseFrame, name, FastPathFactory.fromRBuiltin(builtin, factory));
}
......@@ -648,16 +655,16 @@ public class BasePackage extends RBuiltinPackage {
@Override
public void loadOverrides(MaterializedFrame baseFrame) {
super.loadOverrides(baseFrame);
addFastPath(baseFrame, "matrix", () -> MatrixFastPathNodeGen.create(null), Matrix.class);
addFastPath(baseFrame, "setdiff", () -> SetDiffFastPathNodeGen.create(null));
addFastPath(baseFrame, "get", () -> GetFastPathNodeGen.create(null));
addFastPath(baseFrame, "exists", () -> ExistsFastPathNodeGen.create(null), Exists.class);
addFastPath(baseFrame, "assign", () -> AssignFastPathNodeGen.create(null), Assign.class);
addFastPath(baseFrame, "is.element", () -> IsElementFastPathNodeGen.create(null));
addFastPath(baseFrame, "integer", () -> IntegerFastPathNodeGen.create(null));
addFastPath(baseFrame, "numeric", () -> DoubleFastPathNodeGen.create(null));
addFastPath(baseFrame, "double", () -> DoubleFastPathNodeGen.create(null));
addFastPath(baseFrame, "intersect", () -> IntersectFastPathNodeGen.create(null));
addFastPath(baseFrame, "matrix", MatrixFastPathNodeGen::create, Matrix.class);
addFastPath(baseFrame, "setdiff", SetDiffFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "get", GetFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "exists", ExistsFastPathNodeGen::create, Exists.class);
addFastPath(baseFrame, "assign", AssignFastPathNodeGen::create, Assign.class);
addFastPath(baseFrame, "is.element", IsElementFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "integer", IntegerFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "numeric", DoubleFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "double", DoubleFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "intersect", IntersectFastPathNodeGen::create, RVisibility.ON);
addFastPath(baseFrame, "pmax", FastPathFactory.EVALUATE_ARGS);
addFastPath(baseFrame, "pmin", FastPathFactory.EVALUATE_ARGS);
addFastPath(baseFrame, "cbind", FastPathFactory.FORCED_EAGER_ARGS);
......
......@@ -32,6 +32,7 @@ import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.nodes.RFastPathNode;
import com.oracle.truffle.r.runtime.nodes.RNode;
public abstract class IntersectFastPath extends RFastPathNode {
......@@ -44,7 +45,7 @@ public abstract class IntersectFastPath extends RFastPathNode {
@Cached("createBinaryProfile()") ConditionProfile resultLengthMatchProfile) {
int xLength = x.getLength();
int yLength = y.getLength();
reportWork(xLength + yLength);
RNode.reportWork(this, xLength + yLength);
int count = 0;
int[] result = EMPTY_INT_ARRAY;
......
......@@ -1046,7 +1046,8 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
if (fastPath != null) {
Object result = fastPath.execute(frame, orderedArguments.getArguments());
if (result != null) {
RContext.getInstance().setVisible(this.fastPathVisibility);
assert fastPathVisibility != null;
RContext.getInstance().setVisible(fastPathVisibility);
return result;
}
CompilerDirectives.transferToInterpreterAndInvalidate();
......
......@@ -17,7 +17,6 @@ import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.FastROptions;
import com.oracle.truffle.r.runtime.RArguments.S3Args;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
......@@ -39,7 +38,6 @@ public final class UseMethodInternalNode extends RNode {
}
public Object execute(VirtualFrame frame, RStringVector type, Object[] arguments) {
RContext.getInstance().setVisible(true);
Result lookupResult = lookup.execute(frame, generic, type, null, frame.materialize(), null);
if (wrap) {
assert arguments != null;
......
......@@ -38,10 +38,30 @@ import com.oracle.truffle.r.runtime.nodes.RFastPathNode;
* the function is invoked, the fast path is invoked first and only if it returns {@code null}, then
* the original implementation is invoked.
*/
@FunctionalInterface
public interface FastPathFactory {
FastPathFactory EVALUATE_ARGS = () -> null;
FastPathFactory EVALUATE_ARGS = new FastPathFactory() {
@Override
public RFastPathNode create() {
return null;
}
@Override
public RVisibility getVisibility() {
return null;
}
@Override
public boolean evaluatesArgument(int index) {
return true;
}
@Override
public boolean forcedEagerPromise(int index) {
return false;
}
};
FastPathFactory FORCED_EAGER_ARGS = new FastPathFactory() {
......@@ -50,6 +70,11 @@ public interface FastPathFactory {
return null;
}
@Override
public RVisibility getVisibility() {
return null;
}
@Override
public boolean evaluatesArgument(int index) {
return false;
......@@ -83,23 +108,46 @@ public interface FastPathFactory {
}
return true;
}
@Override
public boolean forcedEagerPromise(int index) {
return false;
}
};
}
RFastPathNode create();
static FastPathFactory fromVisibility(RVisibility visibility, Supplier<RFastPathNode> factory) {
return new FastPathFactory() {
@Override
public RFastPathNode create() {
return factory.get();
}
default boolean evaluatesArgument(@SuppressWarnings("unused") int index) {
return true;
}
@Override
public RVisibility getVisibility() {
return visibility;
}
@Override
public boolean evaluatesArgument(int index) {
return true;
}
default boolean forcedEagerPromise(@SuppressWarnings("unused") int index) {
return false;
@Override
public boolean forcedEagerPromise(int index) {
return false;
}
};
}
RFastPathNode create();
boolean evaluatesArgument(int index);
boolean forcedEagerPromise(int index);
/**
* Visibility of the output. This corresponds to {@link RBuiltin#visibility()}
*/
default RVisibility getVisibility() {
return RVisibility.ON;
}
RVisibility getVisibility();
}
......@@ -26,6 +26,7 @@ import java.util.Arrays;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.FastROptions;
import com.oracle.truffle.r.runtime.RVisibility;
import com.oracle.truffle.r.runtime.builtins.FastPathFactory;
final class EvaluatedArgumentsFastPath implements FastPathFactory {
......@@ -41,6 +42,11 @@ final class EvaluatedArgumentsFastPath implements FastPathFactory {
return null;
}
@Override
public RVisibility getVisibility() {
return null;
}
@Override
public boolean evaluatesArgument(int index) {
return false;
......
......@@ -22,11 +22,9 @@
*/
package com.oracle.truffle.r.runtime.nodes;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.frame.VirtualFrame;
@NodeChild(value = "arguments", type = RNode[].class)
public abstract class RFastPathNode extends RNode {
public abstract class RFastPathNode extends RBaseNode {
public abstract Object execute(VirtualFrame frame, Object... args);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment