diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java index 0a8d63ec7ed252318654e764ef1e79573fbcfcf3..5f3c5e577e81a4c16cb1a55aa006ad9dab0d1fb3 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java @@ -35,9 +35,18 @@ import com.oracle.truffle.r.runtime.RDeparse; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RComplex; +import com.oracle.truffle.r.runtime.data.RComplexVector; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RLanguage; +import com.oracle.truffle.r.runtime.data.RLogicalVector; +import com.oracle.truffle.r.runtime.data.RRaw; +import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RString; +import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RSymbol; +import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; @@ -153,6 +162,56 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 { return RString.valueOf(RDeparse.deparse(val, RDeparse.MAX_CUTOFF, false, RDeparse.KEEPINTEGER, -1)); } + protected static boolean isOneList(Object left, Object right) { + return isRAbstractListVector(left) ^ isRAbstractListVector(right); + } + + @Specialization(guards = {"isOneList(left, right)"}) + protected Object doList(VirtualFrame frame, RAbstractVector left, RAbstractVector right, + @Cached("create()") CastTypeNode cast, + + @Cached("createRecursive()") BinaryBooleanNode recursive) { + Object recursiveLeft = left; + if (isRAbstractListVector(left)) { + recursiveLeft = castListToAtomic(left, cast, right.getRType()); + } + Object recursiveRight = right; + if (isRAbstractListVector(right)) { + recursiveRight = castListToAtomic(right, cast, left.getRType()); + } + return recursive.execute(frame, recursiveLeft, recursiveRight); + } + + @TruffleBoundary + private static Object castListToAtomic(RAbstractVector source, CastTypeNode cast, RType type) { + RVector<?> result = type.create(source.getLength(), false); + Object store = result.getInternalStore(); + for (int i = 0; i < source.getLength(); i++) { + Object value = source.getDataAtAsObject(i); + if (type == RType.Character) { + value = RDeparse.deparse(value); + ((RStringVector) result).setDataAt(store, i, (String) value); + } else { + value = cast.execute(value, type); + if (value instanceof RAbstractVector && ((RAbstractVector) value).getLength() == 1) { + value = ((RAbstractVector) value).getDataAtAsObject(0); + } + if (type == RType.Integer && value instanceof Integer) { + ((RIntVector) result).setDataAt(store, i, (int) value); + } else if (type == RType.Double && value instanceof Double) { + ((RDoubleVector) result).setDataAt(store, i, (double) value); + } else if (type == RType.Logical && value instanceof Byte) { + ((RLogicalVector) result).setDataAt(store, i, (byte) value); + } else if (type == RType.Complex && value instanceof RComplex) { + ((RComplexVector) result).setDataAt(store, i, (RComplex) value); + } else if (type == RType.Raw && value instanceof RRaw) { + ((RRawVector) result).setRawDataAt(store, i, ((RRaw) value).getValue()); + } + } + } + return result; + } + protected BinaryBooleanNode createRecursive() { return BinaryBooleanNode.create(factory); }