Skip to content
Snippets Groups Projects
Commit cd0b4606 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

add "special" hooks for frequently used primitive functions

parent 610c72f4
No related branches found
No related tags found
No related merge requests found
Showing with 321 additions and 13 deletions
......@@ -156,12 +156,16 @@ public abstract class RBuiltinPackage {
}
protected void add(Class<?> builtinClass, Function<RNode[], RBuiltinNode> constructor) {
add(builtinClass, constructor, null);
}
protected void add(Class<?> builtinClass, Function<RNode[], RBuiltinNode> constructor, Function<RNode[], RNode> specialCall) {
RBuiltin annotation = builtinClass.getAnnotation(RBuiltin.class);
String[] parameterNames = annotation.parameterNames();
parameterNames = Arrays.stream(parameterNames).map(n -> n.isEmpty() ? null : n).toArray(String[]::new);
ArgumentsSignature signature = ArgumentsSignature.get(parameterNames);
putBuiltin(new RBuiltinFactory(annotation.name(), builtinClass, annotation.visibility(), annotation.aliases(), annotation.kind(), signature, annotation.nonEvalArgs(), annotation.splitCaller(),
annotation.alwaysSplit(), annotation.dispatch(), constructor, annotation.behavior()));
annotation.alwaysSplit(), annotation.dispatch(), constructor, annotation.behavior(), specialCall));
}
}
......@@ -33,6 +33,7 @@ import com.oracle.truffle.r.nodes.binary.BinaryBooleanScalarNodeGen;
import com.oracle.truffle.r.nodes.binary.ColonNode;
import com.oracle.truffle.r.nodes.binary.ColonNodeGen;
import com.oracle.truffle.r.nodes.builtin.RBuiltinPackage;
import com.oracle.truffle.r.nodes.builtin.base.InfixFunctions.AccessArraySubscriptSpecialBuiltin;
import com.oracle.truffle.r.nodes.builtin.base.fastpaths.AssignFastPathNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.fastpaths.ExistsFastPathNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.fastpaths.GetFastPathNodeGen;
......@@ -386,7 +387,7 @@ public class BasePackage extends RBuiltinPackage {
add(IConv.class, IConvNodeGen::create);
add(Identical.class, Identical::create);
add(NumericalFunctions.Im.class, NumericalFunctionsFactory.ImNodeGen::create);
add(InfixFunctions.AccessArraySubscriptBuiltin.class, InfixFunctionsFactory.AccessArraySubscriptBuiltinNodeGen::create);
add(InfixFunctions.AccessArraySubscriptBuiltin.class, InfixFunctionsFactory.AccessArraySubscriptBuiltinNodeGen::create, AccessArraySubscriptSpecialBuiltin::create);
add(InfixFunctions.AccessArraySubscriptDefaultBuiltin.class, InfixFunctionsFactory.AccessArraySubscriptBuiltinNodeGen::create);
add(InfixFunctions.AccessArraySubsetBuiltin.class, InfixFunctionsFactory.AccessArraySubsetBuiltinNodeGen::create);
add(InfixFunctions.AccessArraySubsetDefaultBuiltin.class, InfixFunctionsFactory.AccessArraySubsetBuiltinNodeGen::create);
......
......@@ -45,9 +45,13 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.InfixFunctionsFactory.AccessArraySubscriptSpecialBuiltinNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.InfixFunctionsFactory.PromiseEvaluatorNodeGen;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
import com.oracle.truffle.r.nodes.function.PromiseHelperNode;
import com.oracle.truffle.r.nodes.function.RCallNode;
import com.oracle.truffle.r.nodes.function.RCallSpecialNode;
import com.oracle.truffle.r.nodes.unary.CastListNode;
import com.oracle.truffle.r.nodes.unary.CastListNodeGen;
import com.oracle.truffle.r.runtime.RDeparse;
......@@ -66,6 +70,8 @@ import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RPromise;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractListVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
......@@ -182,6 +188,59 @@ public class InfixFunctions {
}
@NodeChild(value = "arguments", type = RNode[].class)
public abstract static class AccessArraySubscriptSpecialBuiltin extends RNode {
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
public static RNode create(RNode[] arguments) {
return arguments.length == 2 ? AccessArraySubscriptSpecialBuiltinNodeGen.create(arguments) : null;
}
protected boolean simpleVector(RAbstractVector vector) {
return classHierarchy.execute(vector) == null;
}
protected static boolean inIntRange(RAbstractVector vector, int index) {
return index >= 1 && index <= vector.getLength();
}
protected static boolean inDoubleRange(RAbstractVector vector, double index) {
return index >= 1 && index <= vector.getLength();
}
private static int toInt(double index) {
int i = (int) index;
return i == 0 ? 1 : i - 1;
}
@Specialization(guards = {"simpleVector(vector)", "inIntRange(vector, index)"})
protected static int access(RAbstractIntVector vector, int index) {
return vector.getDataAt(index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "inIntRange(vector, index)"})
protected static double access(RAbstractDoubleVector vector, int index) {
return vector.getDataAt(index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "inDoubleRange(vector, index)"})
protected static int access(RAbstractIntVector vector, double index) {
return vector.getDataAt(toInt(index));
}
@Specialization(guards = {"simpleVector(vector)", "inDoubleRange(vector, index)"})
protected static double access(RAbstractDoubleVector vector, double index) {
return vector.getDataAt(toInt(index));
}
@SuppressWarnings("unused")
@Fallback
protected static Object access(Object vector, Object index) {
throw RCallSpecialNode.fullCallNeeded();
}
}
@RBuiltin(name = "[[", kind = PRIMITIVE, parameterNames = {"", "...", "exact", "drop"}, dispatch = INTERNAL_GENERIC, behavior = PURE)
public abstract static class AccessArraySubscriptBuiltin extends AccessArrayBuiltin {
......
......@@ -47,6 +47,7 @@ import com.oracle.truffle.r.nodes.function.FunctionDefinitionNode;
import com.oracle.truffle.r.nodes.function.FunctionExpressionNode;
import com.oracle.truffle.r.nodes.function.PostProcessArgumentsNode;
import com.oracle.truffle.r.nodes.function.RCallNode;
import com.oracle.truffle.r.nodes.function.RCallSpecialNode;
import com.oracle.truffle.r.nodes.function.SaveArgumentsNode;
import com.oracle.truffle.r.nodes.function.WrapDefaultArgumentNode;
import com.oracle.truffle.r.nodes.unary.GetNonSharedNodeGen;
......@@ -145,7 +146,7 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> {
arg -> (arg.value == null && arg.name == null) ? ConstantNode.create(arg.source == null ? RSyntaxNode.SOURCE_UNAVAILABLE : arg.source, REmpty.instance) : arg.value).toArray(
RSyntaxNode[]::new);
return RCallNode.createCall(source, lhs.asRNode(), signature, nodes);
return RCallSpecialNode.createCall(source, lhs.asRNode(), signature, nodes);
}
private RSyntaxNode createReplacement(SourceSection source, String operator, boolean isSuper, RSyntaxNode replacementLhs, RSyntaxNode replacementRhs) {
......
......@@ -38,8 +38,8 @@ public final class RBuiltinFactory extends RBuiltinDescriptor {
private final Function<RNode[], RBuiltinNode> constructor;
RBuiltinFactory(String name, Class<?> builtinNodeClass, RVisibility visibility, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller,
boolean alwaysSplit, RDispatch dispatch, Function<RNode[], RBuiltinNode> constructor, RBehavior behavior) {
super(name, builtinNodeClass, visibility, aliases, kind, signature, nonEvalArgs, splitCaller, alwaysSplit, dispatch, behavior);
boolean alwaysSplit, RDispatch dispatch, Function<RNode[], RBuiltinNode> constructor, RBehavior behavior, Function<RNode[], RNode> specialCall) {
super(name, builtinNodeClass, visibility, aliases, kind, signature, nonEvalArgs, splitCaller, alwaysSplit, dispatch, behavior, specialCall);
this.constructor = constructor;
}
......
......@@ -67,6 +67,7 @@ import com.oracle.truffle.r.nodes.function.call.CallRFunctionNode;
import com.oracle.truffle.r.nodes.function.call.PrepareArguments;
import com.oracle.truffle.r.nodes.function.visibility.SetVisibilityNode;
import com.oracle.truffle.r.nodes.profile.TruffleBoundaryNode;
import com.oracle.truffle.r.nodes.profile.VectorLengthProfile;
import com.oracle.truffle.r.nodes.unary.CastNode;
import com.oracle.truffle.r.runtime.Arguments;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
......@@ -615,9 +616,13 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
@Override
public void serializeImpl(RSerialize.State state) {
serializeImpl(state, getFunctionNode(), arguments, signature);
}
public static void serializeImpl(RSerialize.State state, RNode functionNode, RSyntaxNode[] arguments, ArgumentsSignature signature) {
state.setAsLangType();
state.serializeNodeSetCar(getFunctionNode());
if (isColon(getFunctionNode())) {
state.serializeNodeSetCar(functionNode);
if (isColon(functionNode)) {
// special case, have to translate Identifier names to Symbols
RSyntaxNode arg0 = arguments[0];
RSyntaxNode arg1 = arguments[1];
......@@ -636,7 +641,7 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
state.linkPairList(2);
state.setCdr(state.closePairList());
} else {
RSyntaxNode f = getFunctionNode().asRSyntaxNode();
RSyntaxNode f = functionNode.asRSyntaxNode();
boolean infixFieldAccess = false;
if (f instanceof RSyntaxLookup) {
RSyntaxLookup lookup = (RSyntaxLookup) f;
......@@ -1014,13 +1019,33 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
return result;
}
private final VectorLengthProfile varArgProfile = VectorLengthProfile.create();
private void forcePromises(VirtualFrame frame, RArgsValuesAndNames varArgs) {
if (varArgsPromiseHelper == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
varArgsPromiseHelper = insert(new PromiseCheckHelperNode());
}
varArgProfile.profile(varArgs.getLength());
int cachedLength = varArgProfile.getCachedLength();
if (cachedLength >= 0) {
forcePromisesUnrolled(frame, varArgs, cachedLength);
} else {
forcePromisesDynamic(frame, varArgs);
}
}
@ExplodeLoop
private void forcePromisesUnrolled(VirtualFrame frame, RArgsValuesAndNames varArgs, int length) {
Object[] array = varArgs.getArguments();
for (int i = 0; i < length; i++) {
array[i] = varArgsPromiseHelper.checkEvaluate(frame, array[i]);
}
}
private void forcePromisesDynamic(VirtualFrame frame, RArgsValuesAndNames varArgs) {
Object[] array = varArgs.getArguments();
for (int i = 0; i < array.length; i++) {
if (varArgsPromiseHelper == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
varArgsPromiseHelper = insert(new PromiseCheckHelperNode());
}
array[i] = varArgsPromiseHelper.checkEvaluate(frame, array[i]);
}
}
......
/*
* Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.nodes.function;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.source.SourceSection;
import com.oracle.truffle.r.nodes.access.variables.LocalReadVariableNode;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RDeparse;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.RSerialize.State;
import com.oracle.truffle.r.runtime.builtins.RBuiltinDescriptor;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RFunction;
import com.oracle.truffle.r.runtime.data.RPromise;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxCall;
import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
final class PeekLocalVariableNode extends RNode {
@Child private LocalReadVariableNode read;
private final ConditionProfile isPromiseProfile = ConditionProfile.createBinaryProfile();
PeekLocalVariableNode(String name) {
this.read = LocalReadVariableNode.create(name, false);
}
@Override
public Object execute(VirtualFrame frame) {
Object value = read.execute(frame);
if (value == null) {
throw RCallSpecialNode.fullCallNeeded();
}
if (isPromiseProfile.profile(value instanceof RPromise)) {
RPromise promise = (RPromise) value;
if (!promise.isEvaluated()) {
throw RCallSpecialNode.fullCallNeeded();
}
return promise.getValue();
}
return value;
}
}
public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode, RSyntaxCall {
public static final RuntimeException FULL_CALL_NEEDED = new FullCallNeededException();
// currently cannot be RSourceSectionNode because of TruffleDSL restrictions
@CompilationFinal private SourceSection sourceSectionR;
@Override
public void setSourceSection(SourceSection sourceSection) {
assert sourceSection != null;
this.sourceSectionR = sourceSection;
}
@Override
public SourceSection getSourceSection() {
return sourceSectionR;
}
@SuppressWarnings("serial")
private static class FullCallNeededException extends RuntimeException {
@Override
public synchronized Throwable fillInStackTrace() {
return null;
}
}
public static RuntimeException fullCallNeeded() {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw FULL_CALL_NEEDED;
}
@Child private ForcePromiseNode functionNode;
@Child private RNode special;
private final RSyntaxNode[] arguments;
private final ArgumentsSignature signature;
private final RFunction expectedFunction;
private RCallSpecialNode(SourceSection sourceSection, RNode functionNode, RFunction expectedFunction, RSyntaxNode[] arguments, ArgumentsSignature signature, RNode special) {
this.sourceSectionR = sourceSection;
this.expectedFunction = expectedFunction;
this.special = special;
this.functionNode = new ForcePromiseNode(functionNode);
this.arguments = arguments;
this.signature = signature;
}
public static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) {
RCallSpecialNode special = tryCreate(sourceSection, functionNode, signature, arguments);
if (special != null) {
if (sourceSection == RSyntaxNode.EAGER_DEPARSE) {
RDeparse.ensureSourceSection(special);
}
return special;
} else {
return RCallNode.createCall(sourceSection, functionNode, signature, arguments);
}
}
private static RCallSpecialNode tryCreate(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) {
if (signature.getNonNullCount() > 0) {
// complex signature -> bail out
return null;
}
RSyntaxNode syntaxFunction = functionNode.asRSyntaxNode();
if (!(syntaxFunction instanceof RSyntaxLookup)) {
// LHS is not a simple lookup -> bail out
return null;
}
for (RSyntaxNode argument : arguments) {
if (!(argument instanceof RSyntaxLookup)) {
// argument is not a simple lookup -> bail out
return null;
}
}
String name = ((RSyntaxLookup) syntaxFunction).getIdentifier();
RBuiltinDescriptor builtinDescriptor = RContext.lookupBuiltinDescriptor(name);
if (builtinDescriptor == null || builtinDescriptor.getSpecialCall() == null) {
// no builtin or no special call definition -> bail out
return null;
}
RNode[] localArguments = new RNode[arguments.length];
for (int i = 0; i < arguments.length; i++) {
localArguments[i] = new PeekLocalVariableNode(((RSyntaxLookup) arguments[i]).getIdentifier());
}
RNode special = builtinDescriptor.getSpecialCall().apply(localArguments);
if (special == null) {
// the factory refused to create a special call -> bail out
return null;
}
RFunction expectedFunction = RContext.lookupBuiltin(name);
RInternalError.guarantee(expectedFunction != null);
return new RCallSpecialNode(sourceSection, functionNode, expectedFunction, arguments, signature, special);
}
@Override
public Object execute(VirtualFrame frame, Object function) {
try {
if (function != expectedFunction) {
// the actual function differs from the expected function
throw RCallSpecialNode.fullCallNeeded();
}
return special.execute(frame);
} catch (FullCallNeededException e) {
CompilerDirectives.transferToInterpreterAndInvalidate();
RCallNode call = RCallNode.createCall(sourceSectionR, functionNode == null ? null : functionNode.getValueNode(), signature, arguments);
return replace(call).execute(frame, function);
}
}
@Override
public Object execute(VirtualFrame frame) {
return execute(frame, functionNode.execute(frame));
}
@Override
public void serializeImpl(State state) {
RCallNode.serializeImpl(state, functionNode.getValueNode(), arguments, signature);
}
@Override
public RSyntaxElement getSyntaxLHS() {
ForcePromiseNode func = functionNode;
return func == null || func.getValueNode() == null ? RSyntaxLookup.createDummyLookup(RSyntaxNode.LAZY_DEPARSE, "FUN", true) : func.getValueNode().asRSyntaxNode();
}
@Override
public ArgumentsSignature getSyntaxSignature() {
return signature == null ? ArgumentsSignature.empty(1) : signature;
}
@Override
public RSyntaxElement[] getSyntaxArguments() {
return arguments == null ? new RSyntaxElement[]{RSyntaxLookup.createDummyLookup(RSyntaxNode.LAZY_DEPARSE, "...", false)} : arguments;
}
}
......@@ -23,12 +23,15 @@
package com.oracle.truffle.r.runtime.builtins;
import java.util.Arrays;
import java.util.function.Function;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.PrimitiveMethodsInfo;
import com.oracle.truffle.r.runtime.RDispatch;
import com.oracle.truffle.r.runtime.RVisibility;
import com.oracle.truffle.r.runtime.nodes.RNode;
public abstract class RBuiltinDescriptor {
......@@ -45,12 +48,14 @@ public abstract class RBuiltinDescriptor {
private final boolean alwaysSplit;
private final RDispatch dispatch;
private final RBehavior behavior;
private final Function<RNode[], RNode> specialCall;
private final int primitiveMethodIndex;
@CompilationFinal private final boolean[] evaluatesArgument;
public RBuiltinDescriptor(String name, Class<?> builtinNodeClass, RVisibility visibility, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller,
boolean alwaysSplit, RDispatch dispatch, RBehavior behavior) {
boolean alwaysSplit, RDispatch dispatch, RBehavior behavior, Function<RNode[], RNode> specialCall) {
this.specialCall = specialCall;
this.name = name.intern();
this.builtinNodeClass = builtinNodeClass;
this.visibility = visibility;
......@@ -131,4 +136,8 @@ public abstract class RBuiltinDescriptor {
public RBehavior getBehavior() {
return behavior;
}
public Function<RNode[], RNode> getSpecialCall() {
return specialCall;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment