From 6adc8369c2ff314e615886879d1f15bd95c53b8f Mon Sep 17 00:00:00 2001 From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com> Date: Wed, 8 Aug 2018 15:03:37 +0200 Subject: [PATCH] A few issues fixed when enabling randomForest --- .../r/ffi/impl/nodes/RForceAndCallNode.java | 3 --- .../r/nodes/builtin/base/UpdateNames.java | 27 ++++++++++++++++++- .../variables/LocalReadVariableNode.java | 7 +++-- .../builtins/TestBuiltin_namesassign.java | 6 +++++ 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/RForceAndCallNode.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/RForceAndCallNode.java index 83a081c08d..9d4e21f1b3 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/RForceAndCallNode.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/RForceAndCallNode.java @@ -31,8 +31,6 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.profiles.ValueProfile; -import com.oracle.truffle.r.nodes.InlineCacheNode; -import com.oracle.truffle.r.nodes.InlineCacheNodeGen; import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode; import com.oracle.truffle.r.nodes.function.PromiseHelperNode; import com.oracle.truffle.r.nodes.function.RCallerHelper; @@ -58,7 +56,6 @@ public abstract class RForceAndCallNode extends RBaseNode { return RForceAndCallNodeGen.create(); } - @Child private InlineCacheNode closureEvalNode = InlineCacheNodeGen.create(10); @Child private PromiseHelperNode promiseHelper = new PromiseHelperNode(); public abstract Object executeObject(Object e, Object f, int n, Object env); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateNames.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateNames.java index c421919dd7..0a00020df6 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateNames.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateNames.java @@ -31,23 +31,29 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.helpers.RFactorNodes.GetLevels; import com.oracle.truffle.r.nodes.unary.CastStringNode; import com.oracle.truffle.r.nodes.unary.CastStringNodeGen; import com.oracle.truffle.r.nodes.unary.GetNonSharedNode; +import com.oracle.truffle.r.nodes.unary.IsFactorNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RStringVector; +import com.oracle.truffle.r.runtime.data.closures.RClosures; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; @RBuiltin(name = "names<-", kind = PRIMITIVE, parameterNames = {"x", "value"}, dispatch = INTERNAL_GENERIC, behavior = PURE) public abstract class UpdateNames extends RBuiltinNode.Arg2 { @Child private CastStringNode castStringNode; + @Child private GetLevels getFactorLevels; static { Casts casts = new Casts(UpdateNames.class); @@ -66,8 +72,17 @@ public abstract class UpdateNames extends RBuiltinNode.Arg2 { @Specialization @TruffleBoundary - protected RAbstractContainer updateNames(RAbstractContainer container, Object names, + protected RAbstractContainer updateNames(RAbstractContainer container, Object namesArg, + @Cached("new()") IsFactorNode isFactorNode, + @Cached("createBinaryProfile()") ConditionProfile isFactorProfile, @Cached("create()") GetNonSharedNode nonShared) { + Object names = namesArg; + if (isFactorProfile.profile(isFactorNode.executeIsFactor(names))) { + final RStringVector levels = getFactorLevels(names); + if (levels != null) { + names = RClosures.createFactorToVector((RAbstractIntVector) names, true, levels); + } + } Object newNames = castString(names); RAbstractContainer result = ((RAbstractContainer) nonShared.execute(container)).materialize(); if (newNames == RNull.instance) { @@ -95,6 +110,16 @@ public abstract class UpdateNames extends RBuiltinNode.Arg2 { return result; } + private RStringVector getFactorLevels(Object names) { + if (getFactorLevels == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + getFactorLevels = insert(GetLevels.create()); + } + assert names instanceof RAbstractIntVector; + final RStringVector levels = getFactorLevels.execute((RAbstractIntVector) names); + return levels; + } + @Specialization protected Object updateNames(RNull n, @SuppressWarnings("unused") RNull names) { return n; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/variables/LocalReadVariableNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/variables/LocalReadVariableNode.java index c739628c0e..3a8f8339ef 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/variables/LocalReadVariableNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/variables/LocalReadVariableNode.java @@ -26,6 +26,7 @@ import com.oracle.truffle.api.Assumption; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.frame.Frame; +import com.oracle.truffle.api.frame.FrameDescriptor; import com.oracle.truffle.api.frame.FrameSlot; import com.oracle.truffle.api.frame.FrameSlotKind; import com.oracle.truffle.api.frame.VirtualFrame; @@ -59,6 +60,7 @@ public final class LocalReadVariableNode extends Node { @CompilationFinal private ConditionProfile isPromiseProfile; @CompilationFinal private FrameSlot frameSlot; + @CompilationFinal private FrameDescriptor frameDescriptor; @CompilationFinal private Assumption notInFrame; @CompilationFinal private Assumption containsNoActiveBindingAssumption; @@ -84,12 +86,13 @@ public final class LocalReadVariableNode extends Node { public Object execute(VirtualFrame frame, Frame variableFrame) { Frame profiledVariableFrame = frameProfile.profile(variableFrame); - if (frameSlot == null && notInFrame == null || (frameSlot != null && frameSlot.getFrameDescriptor() != variableFrame.getFrameDescriptor())) { + if (frameSlot == null && notInFrame == null || (frameSlot != null && frameDescriptor != variableFrame.getFrameDescriptor())) { CompilerDirectives.transferToInterpreterAndInvalidate(); if (identifier.toString().isEmpty()) { throw RError.error(RError.NO_CALLER, RError.Message.ZERO_LENGTH_VARIABLE); } - frameSlot = profiledVariableFrame.getFrameDescriptor().findFrameSlot(identifier); + frameDescriptor = profiledVariableFrame.getFrameDescriptor(); + frameSlot = frameDescriptor.findFrameSlot(identifier); notInFrame = frameSlot == null ? profiledVariableFrame.getFrameDescriptor().getNotInFrameAssumption(identifier) : null; } // check if the slot is missing / wrong type in current frame diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_namesassign.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_namesassign.java index e94fb1a49b..aa4fd44b0f 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_namesassign.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_namesassign.java @@ -167,4 +167,10 @@ public class TestBuiltin_namesassign extends TestBase { public void testUpdateDimnamesPairlist() { assertEval("{ l <- vector('pairlist',2); names(l)<-c('a','b'); l; }"); } + + @Test + public void testUpdateNamesByFactors() { + assertEval("{ x <- c(1,2,1,3); f <- factor(x, labels = c(\"a\",\"b\",\"c\")); names(x)<-f; x; }"); + } + } -- GitLab