diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastSymbolNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastSymbolNode.java index 034b53109fddd47662cb8d0d4807a7398de58876..6f2ac0b11545308c569f3c75100fe53711471cc0 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastSymbolNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastSymbolNode.java @@ -24,19 +24,27 @@ package com.oracle.truffle.r.nodes.unary; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; +import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RLogicalVector; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.RRaw; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RSymbol; +import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; public abstract class CastSymbolNode extends CastBaseNode { @@ -55,8 +63,6 @@ public abstract class CastSymbolNode extends CastBaseNode { return RType.Symbol; } - public abstract Object executeSymbol(Object o); - private String toString(Object value) { return toString.executeString(value, ToStringNode.DEFAULT_SEPARATOR); } @@ -86,6 +92,16 @@ public abstract class CastSymbolNode extends CastBaseNode { return asSymbol(toString(value)); } + @Specialization + protected RSymbol doRaw(RRaw value) { + return asSymbol(toString(value)); + } + + @Specialization + protected RSymbol doComplex(RComplex value) { + return asSymbol(toString(value)); + } + @Specialization @TruffleBoundary protected RSymbol doString(String value) { @@ -93,33 +109,44 @@ public abstract class CastSymbolNode extends CastBaseNode { CompilerDirectives.transferToInterpreter(); throw error(RError.Message.ZERO_LENGTH_VARIABLE); } - return RDataFactory.createSymbolInterned(value); - } - - @Specialization(guards = "value.getLength() > 0") - protected RSymbol doStringVector(RStringVector value) { - // Only element 0 interpreted - return doString(value.getDataAt(0)); + return asSymbol(value); } - @Specialization(guards = "value.getLength() > 0") - protected RSymbol doIntegerVector(RIntVector value) { - return doInteger(value.getDataAt(0)); - } - - @Specialization(guards = "value.getLength() > 0") - protected RSymbol doDoubleVector(RDoubleVector value) { - return doDouble(value.getDataAt(0)); + @Specialization(guards = "access.supports(vector)") + protected RSymbol doVector(RAbstractAtomicVector vector, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile, + @Cached("vector.access()") VectorAccess access) { + SequentialIterator it = access.access(vector); + if (emptyProfile.profile(!access.next(it))) { + throw doEmptyVector(vector); + } + switch (access.getType()) { + case Raw: + return asSymbol(toString(RRaw.valueOf(access.getRaw(it)))); + case Logical: + return doLogical(access.getLogical(it)); + case Integer: + return doInteger(access.getInt(it)); + case Double: + return doDouble(access.getDouble(it)); + case Complex: + return doComplex(access.getComplex(it)); + case Character: + return doString(access.getString(it)); + default: + CompilerDirectives.transferToInterpreter(); + throw RInternalError.shouldNotReachHere("unexpected atomic type " + access.getType()); + } } - @Specialization(guards = "value.getLength() > 0") - protected RSymbol doLogicalVector(RLogicalVector value) { - return doLogical(value.getDataAt(0)); + @Specialization(replaces = "doVector") + protected RSymbol doVectorGeneric(RAbstractAtomicVector vector, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile) { + return doVector(vector, emptyProfile, vector.slowPathAccess()); } - @Specialization(guards = "vector.getLength() == 0") @TruffleBoundary - protected RSymbol doEmptyVector(RAbstractVector vector) { + protected RError doEmptyVector(RAbstractVector vector) { if (vector instanceof RList) { throw error(RError.Message.INVALID_TYPE_LENGTH, "symbol", 0); } else { diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index eff3d08836737fed6f3054a831b4088f97ce71a8..e523da3a8135bed4770db3cb3538d013c8be3b5b 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -7908,6 +7908,14 @@ name #{ as.symbol(123) } `123` +##com.oracle.truffle.r.test.builtins.TestBuiltin_asvector.testAsSymbol# +#{ as.symbol(3+2i) } +`3+2i` + +##com.oracle.truffle.r.test.builtins.TestBuiltin_asvector.testAsSymbol# +#{ as.symbol(as.raw(16)) } +`10` + ##com.oracle.truffle.r.test.builtins.TestBuiltin_asvector.testAsSymbol# #{ as.symbol(as.symbol(123)) } `123` diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_asvector.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_asvector.java index 67ae5bdcae30836bcd3ca9d60de8b0551b7ad765..c66c7d7df15635fe005e845e5910c4a3896e5b26 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_asvector.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_asvector.java @@ -452,5 +452,7 @@ public class TestBuiltin_asvector extends TestBase { assertEval("{ as.symbol(\"name\") }"); assertEval("{ as.symbol(123) }"); assertEval("{ as.symbol(as.symbol(123)) }"); + assertEval("{ as.symbol(as.raw(16)) }"); + assertEval("{ as.symbol(3+2i) }"); } }