Skip to content
Snippets Groups Projects
Commit 6adc8369 authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

A few issues fixed when enabling randomForest

parent 9af0f33c
No related branches found
No related tags found
No related merge requests found
......@@ -31,8 +31,6 @@ import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.InlineCacheNode;
import com.oracle.truffle.r.nodes.InlineCacheNodeGen;
import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
import com.oracle.truffle.r.nodes.function.PromiseHelperNode;
import com.oracle.truffle.r.nodes.function.RCallerHelper;
......@@ -58,7 +56,6 @@ public abstract class RForceAndCallNode extends RBaseNode {
return RForceAndCallNodeGen.create();
}
@Child private InlineCacheNode closureEvalNode = InlineCacheNodeGen.create(10);
@Child private PromiseHelperNode promiseHelper = new PromiseHelperNode();
public abstract Object executeObject(Object e, Object f, int n, Object env);
......
......@@ -31,23 +31,29 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.helpers.RFactorNodes.GetLevels;
import com.oracle.truffle.r.nodes.unary.CastStringNode;
import com.oracle.truffle.r.nodes.unary.CastStringNodeGen;
import com.oracle.truffle.r.nodes.unary.GetNonSharedNode;
import com.oracle.truffle.r.nodes.unary.IsFactorNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.closures.RClosures;
import com.oracle.truffle.r.runtime.data.model.RAbstractContainer;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
@RBuiltin(name = "names<-", kind = PRIMITIVE, parameterNames = {"x", "value"}, dispatch = INTERNAL_GENERIC, behavior = PURE)
public abstract class UpdateNames extends RBuiltinNode.Arg2 {
@Child private CastStringNode castStringNode;
@Child private GetLevels getFactorLevels;
static {
Casts casts = new Casts(UpdateNames.class);
......@@ -66,8 +72,17 @@ public abstract class UpdateNames extends RBuiltinNode.Arg2 {
@Specialization
@TruffleBoundary
protected RAbstractContainer updateNames(RAbstractContainer container, Object names,
protected RAbstractContainer updateNames(RAbstractContainer container, Object namesArg,
@Cached("new()") IsFactorNode isFactorNode,
@Cached("createBinaryProfile()") ConditionProfile isFactorProfile,
@Cached("create()") GetNonSharedNode nonShared) {
Object names = namesArg;
if (isFactorProfile.profile(isFactorNode.executeIsFactor(names))) {
final RStringVector levels = getFactorLevels(names);
if (levels != null) {
names = RClosures.createFactorToVector((RAbstractIntVector) names, true, levels);
}
}
Object newNames = castString(names);
RAbstractContainer result = ((RAbstractContainer) nonShared.execute(container)).materialize();
if (newNames == RNull.instance) {
......@@ -95,6 +110,16 @@ public abstract class UpdateNames extends RBuiltinNode.Arg2 {
return result;
}
private RStringVector getFactorLevels(Object names) {
if (getFactorLevels == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
getFactorLevels = insert(GetLevels.create());
}
assert names instanceof RAbstractIntVector;
final RStringVector levels = getFactorLevels.execute((RAbstractIntVector) names);
return levels;
}
@Specialization
protected Object updateNames(RNull n, @SuppressWarnings("unused") RNull names) {
return n;
......
......@@ -26,6 +26,7 @@ import com.oracle.truffle.api.Assumption;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.frame.Frame;
import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.frame.FrameSlot;
import com.oracle.truffle.api.frame.FrameSlotKind;
import com.oracle.truffle.api.frame.VirtualFrame;
......@@ -59,6 +60,7 @@ public final class LocalReadVariableNode extends Node {
@CompilationFinal private ConditionProfile isPromiseProfile;
@CompilationFinal private FrameSlot frameSlot;
@CompilationFinal private FrameDescriptor frameDescriptor;
@CompilationFinal private Assumption notInFrame;
@CompilationFinal private Assumption containsNoActiveBindingAssumption;
......@@ -84,12 +86,13 @@ public final class LocalReadVariableNode extends Node {
public Object execute(VirtualFrame frame, Frame variableFrame) {
Frame profiledVariableFrame = frameProfile.profile(variableFrame);
if (frameSlot == null && notInFrame == null || (frameSlot != null && frameSlot.getFrameDescriptor() != variableFrame.getFrameDescriptor())) {
if (frameSlot == null && notInFrame == null || (frameSlot != null && frameDescriptor != variableFrame.getFrameDescriptor())) {
CompilerDirectives.transferToInterpreterAndInvalidate();
if (identifier.toString().isEmpty()) {
throw RError.error(RError.NO_CALLER, RError.Message.ZERO_LENGTH_VARIABLE);
}
frameSlot = profiledVariableFrame.getFrameDescriptor().findFrameSlot(identifier);
frameDescriptor = profiledVariableFrame.getFrameDescriptor();
frameSlot = frameDescriptor.findFrameSlot(identifier);
notInFrame = frameSlot == null ? profiledVariableFrame.getFrameDescriptor().getNotInFrameAssumption(identifier) : null;
}
// check if the slot is missing / wrong type in current frame
......
......@@ -167,4 +167,10 @@ public class TestBuiltin_namesassign extends TestBase {
public void testUpdateDimnamesPairlist() {
assertEval("{ l <- vector('pairlist',2); names(l)<-c('a','b'); l; }");
}
@Test
public void testUpdateNamesByFactors() {
assertEval("{ x <- c(1,2,1,3); f <- factor(x, labels = c(\"a\",\"b\",\"c\")); names(x)<-f; x; }");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment