Skip to content
Snippets Groups Projects
Commit 08b14fd4 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

modify lapply to simulate the correct call structure

parent 882a1279
No related branches found
No related tags found
No related merge requests found
......@@ -14,13 +14,17 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.runtime.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.Frame;
import com.oracle.truffle.api.frame.FrameSlot;
import com.oracle.truffle.api.frame.FrameSlotTypeException;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.LoopNode;
import com.oracle.truffle.api.source.Source;
import com.oracle.truffle.api.source.SourceSection;
import com.oracle.truffle.r.nodes.access.WriteVariableNode;
import com.oracle.truffle.r.nodes.access.WriteVariableNode.Mode;
import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
......@@ -28,19 +32,25 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNodeGen;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.LapplyNodeGen.LapplyInternalNodeGen;
import com.oracle.truffle.r.nodes.control.RLengthNode;
import com.oracle.truffle.r.nodes.control.RLengthNodeGen;
import com.oracle.truffle.r.nodes.function.RCallNode;
import com.oracle.truffle.r.runtime.AnonymousFrameVariable;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RBuiltin;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.RSerialize.State;
import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.data.RAttributeProfiles;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RFunction;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.env.REnvironment;
import com.oracle.truffle.r.runtime.nodes.InternalRSyntaxNodeChildren;
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxCall;
import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
/**
* The {@code lapply} builtin. {@code lapply} is an important implicit iterator in R. This
......@@ -66,24 +76,80 @@ public abstract class Lapply extends RBuiltinNode {
return RDataFactory.createList(result, vec.getNames(attrProfiles));
}
public abstract static class LapplyInternalNode extends RBaseNode implements InternalRSyntaxNodeChildren {
private static final class ExtractElementInternal extends RSourceSectionNode implements RSyntaxCall {
private static final String VECTOR_ELEMENT = AnonymousFrameVariable.create("LAPPLY_VEC_ELEM");
protected ExtractElementInternal() {
super(RSyntaxNode.LAZY_DEPARSE);
}
@Child private RLengthNode lengthNode = RLengthNodeGen.create();
@Child private WriteVariableNode writeVectorElement = WriteVariableNode.createAnonymous(VECTOR_ELEMENT, null, Mode.REGULAR);
@Child private ExtractVectorNode extractElementNode = ExtractVectorNodeGen.create(ElementAccessMode.SUBSCRIPT, false);
@Child private RCallNode callNode = createCallNode();
@CompilationFinal private FrameSlot vectorSlot;
@CompilationFinal private FrameSlot indexSlot;
@Override
public Object execute(VirtualFrame frame) {
if (vectorSlot == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
vectorSlot = frame.getFrameDescriptor().findFrameSlot("X");
indexSlot = frame.getFrameDescriptor().findFrameSlot("i");
}
try {
return extractElementNode.apply(frame, frame.getObject(vectorSlot), new Object[]{frame.getInt(indexSlot)}, RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_TRUE);
} catch (FrameSlotTypeException e) {
CompilerDirectives.transferToInterpreter();
throw RInternalError.shouldNotReachHere("frame type mismatch in lapply");
}
}
@Override
public RSyntaxElement getSyntaxLHS() {
return RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "[[", true);
}
@Override
public ArgumentsSignature getSyntaxSignature() {
return ArgumentsSignature.empty(2);
}
@Override
public RSyntaxElement[] getSyntaxArguments() {
return new RSyntaxElement[]{RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "X", false), RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "i", false)};
}
@Override
public RSyntaxNode substituteImpl(REnvironment env) {
throw RInternalError.shouldNotReachHere();
}
@Override
public void serializeImpl(State state) {
throw RInternalError.shouldNotReachHere();
}
}
public abstract static class LapplyInternalNode extends RBaseNode implements InternalRSyntaxNodeChildren {
protected static final String INDEX_NAME = "i";
protected static final String VECTOR_NAME = "X";
public abstract Object[] execute(VirtualFrame frame, Object vector, RFunction function);
protected static FrameSlot createSlot(Frame frame, String name) {
return frame.getFrameDescriptor().findOrAddFrameSlot(name);
}
@Specialization
protected Object[] cachedLApply(VirtualFrame frame, Object vector, RFunction function) {
protected Object[] cachedLApply(VirtualFrame frame, Object vector, RFunction function, //
@Cached("createSlot(frame, INDEX_NAME)") FrameSlot indexSlot, //
@Cached("createSlot(frame, VECTOR_NAME)") FrameSlot vectorSlot, //
@Cached("create()") RLengthNode lengthNode, //
@Cached("createCallNode()") RCallNode callNode) {
// TODO: R switches to double if x.getLength() is greater than 2^31-1
frame.setObject(vectorSlot, vector);
int length = lengthNode.executeInteger(frame, vector);
Object[] result = new Object[length];
for (int i = 1; i <= length; i++) {
writeVectorElement.execute(frame, extractElementNode.apply(frame, vector, new Object[]{i}, RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_TRUE));
frame.setInt(indexSlot, i);
result[i - 1] = callNode.execute(frame, function);
}
return result;
......@@ -95,10 +161,10 @@ public abstract class Lapply extends RBuiltinNode {
protected RCallNode createCallNode() {
CompilerAsserts.neverPartOfCompilation();
ReadVariableNode readVector = ReadVariableNode.createSilent(VECTOR_ELEMENT, RType.Any);
ExtractElementInternal element = new ExtractElementInternal();
ReadVariableNode readArgs = ReadVariableNode.createSilent(ArgumentsSignature.VARARG_NAME, RType.Any);
return RCallNode.createCall(createCallSourceSection(), null, ArgumentsSignature.get(null, "..."), readVector, readArgs);
return RCallNode.createCall(createCallSourceSection(), ReadVariableNode.create("FUN"), ArgumentsSignature.get(null, "..."), element, readArgs);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment