From a3df1416d2457ee24bb8eab5398cf7f2075f7819 Mon Sep 17 00:00:00 2001 From: Adam Welc <adam.welc@oracle.com> Date: Thu, 6 Nov 2014 17:36:30 +0100 Subject: [PATCH] Fixes to factors support. --- .../com/oracle/truffle/r/engine/REngine.java | 1 + .../r/nodes/builtin/base/TypeConvert.java | 2 +- .../nodes/access/array/write/CoerceVector.java | 18 +++++++++++++++++- .../array/write/UpdateArrayHelperNode.java | 5 +++++ .../r/nodes/unary/CastToContainerNode.java | 5 +++++ .../r/test/simple/TestSimpleBuiltins.java | 3 +++ 6 files changed, 32 insertions(+), 2 deletions(-) diff --git a/com.oracle.truffle.r.engine/src/com/oracle/truffle/r/engine/REngine.java b/com.oracle.truffle.r.engine/src/com/oracle/truffle/r/engine/REngine.java index ec84baa9ed..fe2eb079f4 100644 --- a/com.oracle.truffle.r.engine/src/com/oracle/truffle/r/engine/REngine.java +++ b/com.oracle.truffle.r.engine/src/com/oracle/truffle/r/engine/REngine.java @@ -282,6 +282,7 @@ public final class REngine implements RContext.Engine { ConsoleHandler ch = singleton.context.getConsoleHandler(); ch.println("Unsupported specialization in node " + use.getNode().getClass().getSimpleName() + " - supplied values: " + Arrays.asList(use.getSuppliedValues()).stream().map(v -> v.getClass().getSimpleName()).collect(Collectors.toList())); + use.printStackTrace(); return null; } catch (RecognitionException | RuntimeException e) { singleton.context.getConsoleHandler().println("Exception while parsing: " + e); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TypeConvert.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TypeConvert.java index 15e3fe0797..03ef8a58d2 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TypeConvert.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TypeConvert.java @@ -165,7 +165,7 @@ public abstract class TypeConvert extends RBuiltinNode { RIntVector res = RDataFactory.createIntVector(data, complete); res.setAttr("levels", RDataFactory.createStringVector(levelsArray, RDataFactory.COMPLETE_VECTOR)); res.setAttr("class", RDataFactory.createStringVector("factor")); - return res; + return RDataFactory.createFactor(res); } } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/CoerceVector.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/CoerceVector.java index b0185087e8..eec8bfc66d 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/CoerceVector.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/CoerceVector.java @@ -42,6 +42,15 @@ public abstract class CoerceVector extends RNode { @Child private CastIntegerNode castInteger; @Child private CastStringNode castString; @Child private CastListNode castList; + @Child private CoerceVector coerceRecursive; + + private Object coerceRecursive(VirtualFrame frame, Object value, Object vector, Object operand) { + if (coerceRecursive == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + coerceRecursive = insert(CoerceVectorFactory.create(null, null, null)); + } + return coerceRecursive.executeEvaluated(frame, value, vector, operand); + } private Object castComplex(VirtualFrame frame, Object vector) { if (castComplex == null) { @@ -309,6 +318,13 @@ public abstract class CoerceVector extends RNode { return (RList) castList(frame, vector); } + // factor value + + @Specialization + protected Object coerce(VirtualFrame frame, RFactor value, RAbstractVector vector, Object operand) { + return coerceRecursive(frame, value.getVector(), vector, operand); + } + // function vector value @Specialization @@ -338,7 +354,7 @@ public abstract class CoerceVector extends RNode { return vector; } - protected boolean isVectorList(RAbstractVector value, RAbstractVector vector) { + protected boolean isVectorList(RAbstractContainer value, RAbstractVector vector) { return vector.getElementClass() == Object.class; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/UpdateArrayHelperNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/UpdateArrayHelperNode.java index 6c714e11d7..6a1cfd0c93 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/UpdateArrayHelperNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/array/write/UpdateArrayHelperNode.java @@ -177,6 +177,11 @@ public abstract class UpdateArrayHelperNode extends RNode { return CastToContainerNodeFactory.create(child, false, false, false, true); } + @Specialization + protected Object update(VirtualFrame frame, Object v, RFactor value, int recLevel, Object positions, Object vector) { + return updateRecursive(frame, v, value.getVector(), vector, positions, recLevel); + } + @Specialization(guards = "emptyValue") protected RAbstractVector update(Object v, RAbstractVector value, int recLevel, Object[] positions, RAbstractVector vector) { if (isSubset) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastToContainerNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastToContainerNode.java index 5b03886ca5..85cd6acc90 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastToContainerNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastToContainerNode.java @@ -72,6 +72,11 @@ public abstract class CastToContainerNode extends CastNode { return dataFrame; } + @Specialization + protected RFactor cast(RFactor factor) { + return factor; + } + @Specialization protected RExpression cast(RExpression expression) { return expression; diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java index f0f6f4c2f4..90888984d5 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java @@ -3907,6 +3907,9 @@ public class TestSimpleBuiltins extends TestBase { assertEvalWarning("{ x<-factor(c(\"a\", \"b\", \"a\")); x == c(\"a\", \"b\") }"); assertEvalError("{ x<-factor(c(\"a\", \"b\", \"a\")); x > c(\"a\", \"b\") }"); assertEval("{ x<-factor(c(\"a\", \"b\", \"a\", \"c\")); x == c(\"a\", \"b\") }"); + + assertEvalWarning("{ x<-factor(c(\"c\", \"b\", \"a\", \"c\")); y<-list(1); y[1]<-x; y }"); + assertEvalWarning("{ x<-factor(c(\"c\", \"b\", \"a\", \"c\")); y<-c(1); y[1]<-x; y }"); } @Test -- GitLab