diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Lapply.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Lapply.java index 830ee2595ed20a68d9fdf3e9c74b04211d48acc8..21381401026866c1b35231224bf56b32d3c65113 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Lapply.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Lapply.java @@ -14,13 +14,17 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.runtime.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.CompilerAsserts; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.Frame; +import com.oracle.truffle.api.frame.FrameSlot; +import com.oracle.truffle.api.frame.FrameSlotTypeException; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.LoopNode; import com.oracle.truffle.api.source.Source; import com.oracle.truffle.api.source.SourceSection; -import com.oracle.truffle.r.nodes.access.WriteVariableNode; -import com.oracle.truffle.r.nodes.access.WriteVariableNode.Mode; import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode; @@ -28,19 +32,25 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNodeGen; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.LapplyNodeGen.LapplyInternalNodeGen; import com.oracle.truffle.r.nodes.control.RLengthNode; -import com.oracle.truffle.r.nodes.control.RLengthNodeGen; import com.oracle.truffle.r.nodes.function.RCallNode; -import com.oracle.truffle.r.runtime.AnonymousFrameVariable; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RBuiltin; +import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.RSerialize.State; import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.data.RAttributeProfiles; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.nodes.InternalRSyntaxNodeChildren; import com.oracle.truffle.r.runtime.nodes.RBaseNode; +import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxCall; +import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; +import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup; +import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; /** * The {@code lapply} builtin. {@code lapply} is an important implicit iterator in R. This @@ -66,24 +76,80 @@ public abstract class Lapply extends RBuiltinNode { return RDataFactory.createList(result, vec.getNames(attrProfiles)); } - public abstract static class LapplyInternalNode extends RBaseNode implements InternalRSyntaxNodeChildren { + private static final class ExtractElementInternal extends RSourceSectionNode implements RSyntaxCall { - private static final String VECTOR_ELEMENT = AnonymousFrameVariable.create("LAPPLY_VEC_ELEM"); + protected ExtractElementInternal() { + super(RSyntaxNode.LAZY_DEPARSE); + } - @Child private RLengthNode lengthNode = RLengthNodeGen.create(); - @Child private WriteVariableNode writeVectorElement = WriteVariableNode.createAnonymous(VECTOR_ELEMENT, null, Mode.REGULAR); @Child private ExtractVectorNode extractElementNode = ExtractVectorNodeGen.create(ElementAccessMode.SUBSCRIPT, false); - @Child private RCallNode callNode = createCallNode(); + @CompilationFinal private FrameSlot vectorSlot; + @CompilationFinal private FrameSlot indexSlot; + + @Override + public Object execute(VirtualFrame frame) { + if (vectorSlot == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + vectorSlot = frame.getFrameDescriptor().findFrameSlot("X"); + indexSlot = frame.getFrameDescriptor().findFrameSlot("i"); + } + try { + return extractElementNode.apply(frame, frame.getObject(vectorSlot), new Object[]{frame.getInt(indexSlot)}, RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_TRUE); + } catch (FrameSlotTypeException e) { + CompilerDirectives.transferToInterpreter(); + throw RInternalError.shouldNotReachHere("frame type mismatch in lapply"); + } + } + + @Override + public RSyntaxElement getSyntaxLHS() { + return RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "[[", true); + } + + @Override + public ArgumentsSignature getSyntaxSignature() { + return ArgumentsSignature.empty(2); + } + + @Override + public RSyntaxElement[] getSyntaxArguments() { + return new RSyntaxElement[]{RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "X", false), RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "i", false)}; + } + + @Override + public RSyntaxNode substituteImpl(REnvironment env) { + throw RInternalError.shouldNotReachHere(); + } + + @Override + public void serializeImpl(State state) { + throw RInternalError.shouldNotReachHere(); + } + } + + public abstract static class LapplyInternalNode extends RBaseNode implements InternalRSyntaxNodeChildren { + + protected static final String INDEX_NAME = "i"; + protected static final String VECTOR_NAME = "X"; public abstract Object[] execute(VirtualFrame frame, Object vector, RFunction function); + protected static FrameSlot createSlot(Frame frame, String name) { + return frame.getFrameDescriptor().findOrAddFrameSlot(name); + } + @Specialization - protected Object[] cachedLApply(VirtualFrame frame, Object vector, RFunction function) { + protected Object[] cachedLApply(VirtualFrame frame, Object vector, RFunction function, // + @Cached("createSlot(frame, INDEX_NAME)") FrameSlot indexSlot, // + @Cached("createSlot(frame, VECTOR_NAME)") FrameSlot vectorSlot, // + @Cached("create()") RLengthNode lengthNode, // + @Cached("createCallNode()") RCallNode callNode) { // TODO: R switches to double if x.getLength() is greater than 2^31-1 + frame.setObject(vectorSlot, vector); int length = lengthNode.executeInteger(frame, vector); Object[] result = new Object[length]; for (int i = 1; i <= length; i++) { - writeVectorElement.execute(frame, extractElementNode.apply(frame, vector, new Object[]{i}, RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_TRUE)); + frame.setInt(indexSlot, i); result[i - 1] = callNode.execute(frame, function); } return result; @@ -95,10 +161,10 @@ public abstract class Lapply extends RBuiltinNode { protected RCallNode createCallNode() { CompilerAsserts.neverPartOfCompilation(); - ReadVariableNode readVector = ReadVariableNode.createSilent(VECTOR_ELEMENT, RType.Any); + ExtractElementInternal element = new ExtractElementInternal(); ReadVariableNode readArgs = ReadVariableNode.createSilent(ArgumentsSignature.VARARG_NAME, RType.Any); - return RCallNode.createCall(createCallSourceSection(), null, ArgumentsSignature.get(null, "..."), readVector, readArgs); + return RCallNode.createCall(createCallSourceSection(), ReadVariableNode.create("FUN"), ArgumentsSignature.get(null, "..."), element, readArgs); } }