diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/RBuiltinPackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/RBuiltinPackage.java index 76727cee1decfc8c387c6d2a9915a7170c24cd40..00f6d08e2fa89cb91eeedab40c44cb94b104d5f3 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/RBuiltinPackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/RBuiltinPackage.java @@ -150,6 +150,6 @@ public abstract class RBuiltinPackage { parameterNames = Arrays.stream(parameterNames).map(n -> n.isEmpty() ? null : n).toArray(String[]::new); ArgumentsSignature signature = ArgumentsSignature.get(parameterNames); - putBuiltin(new RBuiltinFactory(annotation.name(), annotation.aliases(), annotation.kind(), signature, annotation.nonEvalArgs(), annotation.splitCaller(), constructor)); + putBuiltin(new RBuiltinFactory(annotation.name(), annotation.aliases(), annotation.kind(), signature, annotation.nonEvalArgs(), annotation.splitCaller(), annotation.alwaysSplit(), constructor)); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java index 0dac20fef399bdbe96102d9d9430f2572769e22a..c1fc56a3f1a9ced5359a12a927619b1bb3415dd9 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java @@ -41,8 +41,8 @@ public abstract class MatMult extends RBuiltinNode { private static final int BLOCK_SIZE = 64; - @Child private ScalarArithmeticNode mult = new ScalarArithmeticNode(BinaryArithmetic.MULTIPLY.create()); - @Child private ScalarArithmeticNode add = new ScalarArithmeticNode(BinaryArithmetic.ADD.create()); + @Child private ScalarBinaryArithmeticNode mult = new ScalarBinaryArithmeticNode(BinaryArithmetic.MULTIPLY.create()); + @Child private ScalarBinaryArithmeticNode add = new ScalarBinaryArithmeticNode(BinaryArithmetic.ADD.create()); private final BranchProfile errorProfile = BranchProfile.create(); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mod.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mod.java index bc237d14425a7f6df9212249a2df2c28b66ea531..4cc9cbc62b594f3abb9a95a7dfeb1491d4b948fa 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mod.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mod.java @@ -36,8 +36,8 @@ import com.oracle.truffle.r.runtime.ops.*; @RBuiltin(name = "Mod", kind = PRIMITIVE, parameterNames = {"z"}) public abstract class Mod extends RBuiltinNode { - @Child private ScalarArithmeticNode pow = new ScalarArithmeticNode(BinaryArithmetic.POW.create()); - @Child private ScalarArithmeticNode add = new ScalarArithmeticNode(BinaryArithmetic.ADD.create()); + @Child private ScalarBinaryArithmeticNode pow = new ScalarBinaryArithmeticNode(BinaryArithmetic.POW.create()); + @Child private ScalarBinaryArithmeticNode add = new ScalarBinaryArithmeticNode(BinaryArithmetic.ADD.create()); @Child private Sqrt sqrt = SqrtNodeGen.create(new RNode[1], null, null); @Specialization diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ArithmeticTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ArithmeticTest.java new file mode 100644 index 0000000000000000000000000000000000000000..054e7d4d00ef09c13b2b9296f21cac8ff9cdeadc --- /dev/null +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ArithmeticTest.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.test; + +import static com.oracle.truffle.r.runtime.data.RDataFactory.*; + +import org.junit.experimental.theories.*; + +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.model.*; + +public class ArithmeticTest extends TestBase { + + @DataPoint public static final RScalarVector PRIMITIVE_LOGICAL = RLogical.valueOf((byte) 1); + @DataPoint public static final RScalarVector PRIMITIVE_INTEGER = RInteger.valueOf(42); + @DataPoint public static final RScalarVector PRIMITIVE_DOUBLE = RDouble.valueOf(42d); + @DataPoint public static final RScalarVector PRIMITIVE_COMPLEX = RComplex.valueOf(1.0, 1.0); + + @DataPoint public static final RAbstractVector EMPTY_LOGICAL = createEmptyLogicalVector(); + @DataPoint public static final RAbstractVector EMPTY_INTEGER = createEmptyIntVector(); + @DataPoint public static final RAbstractVector EMPTY_DOUBLE = createEmptyDoubleVector(); + @DataPoint public static final RAbstractVector EMPTY_COMPLEX = createEmptyComplexVector(); + + @DataPoint public static final RSequence SEQUENCE_INT = createIntSequence(1, 2, 10); + @DataPoint public static final RSequence SEQUENCE_DOUBLE = createDoubleSequence(1, 2, 10); + + @DataPoint public static final RAbstractVector FOUR_LOGICAL = createLogicalVector(new byte[]{1, 0, 1, 0}, true); + @DataPoint public static final RAbstractVector FOUR_INT = createIntVector(new int[]{1, 2, 3, 4}, true); + @DataPoint public static final RAbstractVector FOUR_DOUBLE = createDoubleVector(new double[]{1, 2, 3, 4}, true); + @DataPoint public static final RAbstractVector FOUR_COMPLEX = createComplexVector(new double[]{1, 1, 2, 2, 3, 3, 4, 4}, true); + + @DataPoint public static final RAbstractVector NOT_COMPLETE_LOGICAL = createLogicalVector(new byte[]{1, 0, RRuntime.LOGICAL_NA, 1}, false); + @DataPoint public static final RAbstractVector NOT_COMPLETE_INT = createIntVector(new int[]{1, 2, RInteger.NA.getValue(), 4}, false); + @DataPoint public static final RAbstractVector NOT_COMPLETE_DOUBLE = createDoubleVector(new double[]{1, 2, RDouble.NA.getValue(), 4}, false); + @DataPoint public static final RAbstractVector NOT_COMPLETE_COMPLEX = createComplexVector(new double[]{1.0d, 0.0d, RRuntime.COMPLEX_NA_REAL_PART, RRuntime.COMPLEX_NA_IMAGINARY_PART}, false); + + @DataPoint public static final RAbstractVector ONE = createIntVector(new int[]{1}, true); + @DataPoint public static final RAbstractVector TWO = createIntVector(new int[]{1, 2}, true); + @DataPoint public static final RAbstractVector THREE = createIntVector(new int[]{1, 2, 3}, true); + @DataPoint public static final RAbstractVector FIVE = createIntVector(new int[]{1, 2, 3, 4, 5}, true); + +} diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/BinaryArithmeticNodeTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/BinaryArithmeticNodeTest.java index e1187dd617449c8bb718630c6081186a3775d065..c0fd13ef8240a8c9f2fa48f7bf11e11d090aab4f 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/BinaryArithmeticNodeTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/BinaryArithmeticNodeTest.java @@ -48,38 +48,10 @@ import com.oracle.truffle.r.runtime.ops.*; * should NOT verify correctness. This is done by the integration test suite. */ @RunWith(Theories.class) -public class BinaryArithmeticNodeTest extends TestBase { +public class BinaryArithmeticNodeTest extends ArithmeticTest { @DataPoints public static final BinaryArithmeticFactory[] BINARY = ALL; - @DataPoint public static final RScalarVector PRIMITIVE_LOGICAL = RLogical.valueOf((byte) 1); - @DataPoint public static final RScalarVector PRIMITIVE_INTEGER = RInteger.valueOf(42); - @DataPoint public static final RScalarVector PRIMITIVE_DOUBLE = RDouble.valueOf(42d); - @DataPoint public static final RScalarVector PRIMITIVE_COMPLEX = RComplex.valueOf(1.0, 1.0); - - @DataPoint public static final RAbstractVector EMPTY_LOGICAL = createEmptyLogicalVector(); - @DataPoint public static final RAbstractVector EMPTY_INTEGER = createEmptyIntVector(); - @DataPoint public static final RAbstractVector EMPTY_DOUBLE = createEmptyDoubleVector(); - @DataPoint public static final RAbstractVector EMPTY_COMPLEX = createEmptyComplexVector(); - - @DataPoint public static final RSequence SEQUENCE_INT = createIntSequence(1, 2, 10); - @DataPoint public static final RSequence SEQUENCE_DOUBLE = createDoubleSequence(1, 2, 10); - - @DataPoint public static final RAbstractVector FOUR_LOGICAL = createLogicalVector(new byte[]{1, 0, 1, 0}, true); - @DataPoint public static final RAbstractVector FOUR_INT = createIntVector(new int[]{1, 2, 3, 4}, true); - @DataPoint public static final RAbstractVector FOUR_DOUBLE = createDoubleVector(new double[]{1, 2, 3, 4}, true); - @DataPoint public static final RAbstractVector FOUR_COMPLEX = createComplexVector(new double[]{1, 1, 2, 2, 3, 3, 4, 4}, true); - - @DataPoint public static final RAbstractVector NOT_COMPLETE_LOGICAL = createLogicalVector(new byte[]{1, 2, RRuntime.LOGICAL_NA, 4}, false); - @DataPoint public static final RAbstractVector NOT_COMPLETE_INT = createIntVector(new int[]{1, 2, RInteger.NA.getValue(), 4}, false); - @DataPoint public static final RAbstractVector NOT_COMPLETE_DOUBLE = createDoubleVector(new double[]{1, 2, RDouble.NA.getValue(), 4}, false); - @DataPoint public static final RAbstractVector NOT_COMPLETE_COMPLEX = createComplexVector(new double[]{1.0d, 0.0d, RRuntime.COMPLEX_NA_REAL_PART, RRuntime.COMPLEX_NA_IMAGINARY_PART}, false); - - @DataPoint public static final RAbstractVector ONE = createIntVector(new int[]{1}, true); - @DataPoint public static final RAbstractVector TWO = createIntVector(new int[]{1, 2}, true); - @DataPoint public static final RAbstractVector THREE = createIntVector(new int[]{1, 2, 3}, true); - @DataPoint public static final RAbstractVector FIVE = createIntVector(new int[]{1, 2, 3, 4, 5}, true); - @Theory public void testScalarUnboxing(BinaryArithmeticFactory factory, RScalarVector a, RAbstractVector b) { // unboxing cannot work if length is 1 diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..9401d23f1010ca2979d83587a70cb48559cd1c82 --- /dev/null +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.test; + +import static com.oracle.truffle.r.nodes.test.TestUtilities.*; +import static com.oracle.truffle.r.runtime.data.RDataFactory.*; +import static com.oracle.truffle.r.runtime.ops.UnaryArithmetic.*; +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.*; +import static org.junit.Assume.*; + +import java.util.*; + +import org.junit.*; +import org.junit.experimental.theories.*; +import org.junit.runner.*; + +import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle; +import com.oracle.truffle.r.nodes.unary.*; +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.RAttributes.RAttribute; +import com.oracle.truffle.r.runtime.data.model.*; +import com.oracle.truffle.r.runtime.ops.*; + +/** + * This test verifies white box assumptions for the arithmetic node. Please note that this node + * should NOT verify correctness. This is done by the integration test suite. + */ +@RunWith(Theories.class) +public class UnaryArithmeticNodeTest extends ArithmeticTest { + + @DataPoints public static final UnaryArithmeticFactory[] UNARY = ALL; + + @Theory + public void testVectorResult(UnaryArithmeticFactory factory, RAbstractVector operand) { + assumeThat(operand, is(not(instanceOf(RScalarVector.class)))); + + Object result = executeArithmetic(factory, operand); + Assert.assertFalse(isPrimitive(result)); + assumeThat(result, is(instanceOf(RAbstractVector.class))); + RAbstractVector resultCast = (RAbstractVector) result; + + assertThat(resultCast.getLength(), is(equalTo(operand.getLength()))); + } + + @Theory + public void testSharing(UnaryArithmeticFactory factory, RAbstractVector a) { + // sharing does not work if a is a scalar vector + assumeThat(true, is(isShareable(a, a.getRType()))); + + RType resultType = getArgumentType(a); + Object sharedResult = null; + if (isShareable(a, resultType)) { + sharedResult = a; + } + + Object result = executeArithmetic(factory, a); + if (sharedResult == null) { + Assert.assertNotSame(a, result); + } else { + Assert.assertSame(sharedResult, result); + } + } + + private static boolean isShareable(RAbstractVector a, RType resultType) { + if (a.getRType() != resultType) { + // needs cast -> not shareable + return false; + } + + if (a instanceof RShareable) { + if (((RShareable) a).isTemporary()) { + return true; + } + } + return false; + } + + @Theory + public void testCompleteness(UnaryArithmeticFactory factory, RAbstractVector operand) { + Object result = executeArithmetic(factory, operand); + + boolean resultComplete = isPrimitive(result) ? true : ((RAbstractVector) result).isComplete(); + + if (operand.getLength() == 0) { + Assert.assertTrue(resultComplete); + } else { + boolean expectedComplete = operand.isComplete(); + Assert.assertEquals(expectedComplete, resultComplete); + } + } + + @Theory + public void testCopyAttributes(UnaryArithmeticFactory factory, RAbstractVector operand) { + // we have to e careful not to change mutable vectors + RAbstractVector a = operand.copy(); + if (a instanceof RShareable) { + ((RShareable) a).markNonTemporary(); + } + + RVector aMaterialized = a.copy().materialize(); + aMaterialized.setAttr("a", "a"); + assertAttributes(executeArithmetic(factory, aMaterialized.copy()), "a"); + } + + @Theory + public void testPlusFolding(RAbstractVector operand) { + assumeThat(operand, is(not(instanceOf(RScalarVector.class)))); + if (operand.getRType() == getArgumentType(operand)) { + assertFold(true, operand, PLUS); + } else { + assertFold(false, operand, PLUS); + } + } + + @Test + public void testSequenceFolding() { + assertFold(true, createIntSequence(1, 3, 10), NEGATE); + assertFold(true, createDoubleSequence(1, 3, 10), NEGATE); + assertFold(false, createIntSequence(1, 3, 10), ROUND, FLOOR, CEILING); + assertFold(false, createDoubleSequence(1, 3, 10), ROUND, FLOOR, CEILING); + } + + private static void assertAttributes(Object value, String... keys) { + if (!(value instanceof RAbstractVector)) { + Assert.assertEquals(0, keys.length); + return; + } + + RAbstractVector vector = (RAbstractVector) value; + Set<String> expectedAttributes = new HashSet<>(Arrays.asList(keys)); + + RAttributes attributes = vector.getAttributes(); + if (attributes == null) { + Assert.assertEquals(0, keys.length); + return; + } + Set<Object> foundAttributes = new HashSet<>(); + for (RAttribute attribute : attributes) { + foundAttributes.add(attribute.getName()); + foundAttributes.add(attribute.getValue()); + } + Assert.assertEquals(expectedAttributes, foundAttributes); + } + + private static RType getArgumentType(RAbstractVector operand) { + return RType.maxPrecedence(RType.Integer, operand.getRType()); + } + + private static boolean isPrimitive(Object result) { + return result instanceof Integer || result instanceof Double || result instanceof Byte || result instanceof RComplex; + } + + private void assertFold(boolean expectedFold, RAbstractVector operand, UnaryArithmeticFactory... arithmetics) { + for (int i = 0; i < arithmetics.length; i++) { + UnaryArithmeticFactory factory = arithmetics[i]; + Object result = executeArithmetic(factory, operand); + if (expectedFold) { + assertThat(String.format("expected fold %s <op> ", operand), result instanceof RSequence || result == operand); + } else { + assertThat(String.format("expected not fold %s <op> ", operand), !(result instanceof RSequence)); + } + } + } + + private NodeHandle<UnaryArithmeticNode> handle; + private UnaryArithmeticFactory currentFactory; + + @Before + public void setUp() { + handle = null; + } + + @After + public void tearDown() { + handle = null; + } + + private Object executeArithmetic(UnaryArithmeticFactory factory, Object operand) { + if (handle == null || this.currentFactory != factory) { + handle = create(factory); + this.currentFactory = factory; + } + return handle.call(operand); + } + + private static NodeHandle<UnaryArithmeticNode> create(UnaryArithmeticFactory factory) { + return createHandle(UnaryArithmeticNodeGen.create(factory, null, null), // + (node, args) -> node.execute(args[0])); + } + +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RNode.java index 40943001e18bf5d808a42a2e18a76d63fa1e20a4..4cf2aad1faf1145bf74aa407cd8fac822e4e7ec1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RNode.java @@ -224,19 +224,18 @@ public abstract class RNode extends Node implements RSyntaxNode, RInstrumentable * @param amount an approximation of the number of operations */ protected void reportWork(long amount) { - if (CompilerDirectives.inInterpreter()) { - reportWorkInternal(amount); - } + reportWork(this, amount); } - private void reportWorkInternal(long amount) { - CompilerAsserts.neverPartOfCompilation(); - if (amount >= WORK_SCALE_FACTOR) { - int scaledAmount = (int) (amount / WORK_SCALE_FACTOR); - if (amount > 0) { - RootNode root = getRootNode(); - if (root.getCallTarget() instanceof LoopCountReceiver) { - ((LoopCountReceiver) root.getCallTarget()).reportLoopCount(scaledAmount); + public static void reportWork(Node base, long amount) { + if (CompilerDirectives.inInterpreter()) { + if (amount >= WORK_SCALE_FACTOR) { + int scaledAmount = (int) (amount / WORK_SCALE_FACTOR); + if (amount > 0) { + RootNode root = base.getRootNode(); + if (root != null && root.getCallTarget() instanceof LoopCountReceiver) { + ((LoopCountReceiver) root.getCallTarget()).reportLoopCount(scaledAmount); + } } } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RRootNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RRootNode.java index 1206fbb6064cc8f556f7081e250fe8c47bdc55ab..bbbf8fce1381193f126d217ff01a468b7edd35b1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RRootNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RRootNode.java @@ -80,6 +80,10 @@ public abstract class RRootNode extends RootNode implements HasSignature { return formalArguments.getSignature(); } + public boolean needsSplitting() { + return false; + } + @TruffleBoundary public String getSourceCode() { SourceSection ss = getSourceSection(); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java index da98bc088189a3fb6e56dc15cf07851a2a4f4834..acb65f290855a9965bd0f0169d314142db512202 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java @@ -60,7 +60,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode { return BinaryArithmeticNodeGen.create(binary, unary, new RNode[]{null, null}, null, null); } - @Specialization(guards = {"cached != null", "cached.isSupported(left, right)"}) + @Specialization(limit = "3", guards = {"cached != null", "cached.isSupported(left, right)"}) protected Object doNumericVectorCached(Object left, Object right, // @Cached("createFastCached(left, right)") VectorBinaryNode cached) { return cached.apply(left, right); @@ -88,12 +88,16 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode { } protected static Class<? extends RAbstractVector> getVectorClass(Object value) { - if (value instanceof RAbstractVector && ((RAbstractVector) value).getRType().isNumeric()) { + if (isNumericVector(value)) { return ((RAbstractVector) value).getClass(); } return null; } + private static boolean isNumericVector(Object value) { + return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector; + } + protected static boolean isNonNumericVector(Object value) { return value instanceof RAbstractVector && !((RAbstractVector) value).getRType().isNumeric(); } @@ -173,7 +177,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode { resultType = RType.Double; } - return new VectorBinaryNode(new ScalarArithmeticNode(innerArithmetic), leftVector.getClass(), rightVector.getClass(), argumentType, resultType); + return new VectorBinaryNode(new ScalarBinaryArithmeticNode(innerArithmetic), leftVector.getClass(), rightVector.getClass(), argumentType, resultType); } protected static final class LRUCache { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BoxPrimitiveNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BoxPrimitiveNode.java index 111e23535406d0df95f030e353d3aa7b8ecf4190..d349662b4ee0c782e8023d738add14cece2ed76a 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BoxPrimitiveNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BoxPrimitiveNode.java @@ -32,7 +32,7 @@ import com.oracle.truffle.r.runtime.data.model.*; * analogies. */ @NodeChild("operand") -abstract class BoxPrimitiveNode extends RNode { +public abstract class BoxPrimitiveNode extends RNode { @Specialization protected static RAbstractVector doInt(int vector) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarBinaryArithmeticNode.java similarity index 98% rename from com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarArithmeticNode.java rename to com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarBinaryArithmeticNode.java index 1b72bb511ff02f1dbe6a264f9ff5c768fea06644..a6bffcfdc31694a97ad49236a9573abeed67a2b6 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/ScalarBinaryArithmeticNode.java @@ -32,11 +32,11 @@ import com.oracle.truffle.r.runtime.ops.na.*; /** * */ -public final class ScalarArithmeticNode extends ScalarBinaryNode { +public final class ScalarBinaryArithmeticNode extends ScalarBinaryNode { @Child private BinaryArithmetic arithmetic; - public ScalarArithmeticNode(BinaryArithmetic arithmetic) { + public ScalarBinaryArithmeticNode(BinaryArithmetic arithmetic) { this.arithmetic = arithmetic; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorBinaryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorBinaryNode.java index b498cdc764d1285731bf266685ab96f8ff0073f4..bf108bbcfc573521e51cad2e119677bcbe76c81d 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorBinaryNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorBinaryNode.java @@ -25,6 +25,7 @@ package com.oracle.truffle.r.nodes.binary; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.nodes.*; import com.oracle.truffle.api.utilities.*; +import com.oracle.truffle.r.nodes.*; import com.oracle.truffle.r.nodes.profile.*; import com.oracle.truffle.r.runtime.*; import com.oracle.truffle.r.runtime.data.*; @@ -35,13 +36,13 @@ import com.oracle.truffle.r.runtime.data.model.*; * ensures that attributes and dimensions are properly migrated from the source vectors to the * result vectors. It also implements sharing of temporary vectors as result vector. Internally it * uses a {@link ScalarBinaryNode} to abstract one scalar operation invocation on the vector and - * {@link IndexedVectorIterationNode} to abstract the iteration over two arrays with potentially - * differing length. The {@link ScalarBinaryNode} instance can be passed from the outside in order - * to enable the use for different scalar operations like logic and arithmetic operations. + * {@link VectorMapBinaryNode} to abstract the iteration over two arrays with potentially differing + * length. The {@link ScalarBinaryNode} instance can be passed from the outside in order to enable + * the use for different scalar operations like logic and arithmetic operations. */ final class VectorBinaryNode extends Node { - @Child private IndexedVectorIterationNode vectorNode; + @Child private VectorMapBinaryNode vectorNode; @Child private ScalarBinaryNode scalarNode; // profiles @@ -67,7 +68,7 @@ final class VectorBinaryNode extends Node { this.scalarNode = scalarNode; this.leftClass = leftclass; this.rightClass = rightClass; - this.vectorNode = IndexedVectorIterationNode.create(resultType, argumentType); + this.vectorNode = VectorMapBinaryNode.create(resultType, argumentType); this.scalarTypes = RScalarVector.class.isAssignableFrom(leftclass) && RScalarVector.class.isAssignableFrom(rightClass); boolean leftVectorImpl = RVector.class.isAssignableFrom(leftclass); boolean rightVectorImpl = RVector.class.isAssignableFrom(rightClass); @@ -174,7 +175,8 @@ final class VectorBinaryNode extends Node { target = scalarNode.tryFoldConstantTime(leftCast, leftLength, rightCast, rightLength); } if (target == null) { - target = createOrShareVector(leftLength, left, rightLength, right); + int maxLength = Math.max(leftLength, rightLength); + target = createOrShareVector(leftLength, left, rightLength, right, maxLength); Object store; if (target instanceof RAccessibleStore) { store = ((RAccessibleStore<?>) target).getInternalStore(); @@ -182,6 +184,7 @@ final class VectorBinaryNode extends Node { throw RInternalError.shouldNotReachHere(); } vectorNode.apply(scalarNode, store, leftCast, leftLength, rightCast, rightLength); + RNode.reportWork(this, maxLength); } if (mayContainMetadata) { target = handleMetadata(target, left, leftLength, right, rightLength); @@ -190,8 +193,8 @@ final class VectorBinaryNode extends Node { return target; } - private RAbstractVector createOrShareVector(int leftLength, RAbstractVector left, int rightLength, RAbstractVector right) { - int maxLength = Math.max(leftLength, rightLength); + private RAbstractVector createOrShareVector(int leftLength, RAbstractVector left, int rightLength, RAbstractVector right, int maxLength) { + RType resultType = getResultType(); if (mayShareLeft && left.getRType() == resultType && shareLeft.profile(leftLength == maxLength && ((RShareable) left).isTemporary())) { return left; @@ -199,24 +202,7 @@ final class VectorBinaryNode extends Node { if (mayShareRight && right.getRType() == resultType && shareRight.profile(rightLength == maxLength && ((RShareable) right).isTemporary())) { return right; } - return createResult(maxLength); - } - - private RAbstractVector createResult(int length) { - switch (getResultType()) { - case Logical: - return RDataFactory.createLogicalVector(length); - case Integer: - return RDataFactory.createIntVector(length); - case Double: - return RDataFactory.createDoubleVector(length); - case Complex: - return RDataFactory.createComplexVector(length); - case Character: - return RDataFactory.createStringVector(length); - default: - throw RInternalError.shouldNotReachHere(); - } + return getResultType().create(maxLength); } private RType getArgumentType() { @@ -243,6 +229,7 @@ final class VectorBinaryNode extends Node { @TruffleBoundary private void copyAttributesInternal(RVector result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength) { + // TODO this method needs its own specializing node if (leftLength == rightLength) { if (result != right) { result.copyRegAttributesFrom(right); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/IndexedVectorIterationNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorMapBinaryNode.java similarity index 97% rename from com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/IndexedVectorIterationNode.java rename to com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorMapBinaryNode.java index 29c1df3985fe49fdf7c6a69464e8038ad1ea39b5..b51e7ee9d047ea3c8b4222b4c64a0e485fa0e60b 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/IndexedVectorIterationNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/VectorMapBinaryNode.java @@ -35,7 +35,7 @@ import com.oracle.truffle.r.runtime.data.model.*; * implementation also ensures that there is no internal boxing inside the loop. */ @SuppressWarnings("unused") -abstract class IndexedVectorIterationNode extends Node { +abstract class VectorMapBinaryNode extends Node { private static final MapBinaryIndexedAction<byte[], RAbstractLogicalVector> LOGICAL = // (arithmetic, result, resultIndex, left, leftIndex, right, rightIndex) -> { @@ -73,7 +73,7 @@ abstract class IndexedVectorIterationNode extends Node { private final RType resultType; @SuppressWarnings("unchecked") - protected IndexedVectorIterationNode(RType resultType, RType argumentType) { + protected VectorMapBinaryNode(RType resultType, RType argumentType) { this.indexedAction = (MapBinaryIndexedAction<Object, RAbstractVector>) createIndexedAction(resultType, argumentType); this.argumentType = argumentType; this.resultType = resultType; @@ -87,8 +87,8 @@ abstract class IndexedVectorIterationNode extends Node { return resultType; } - public static IndexedVectorIterationNode create(RType resultType, RType argumentType) { - return IndexedVectorIterationNodeGen.create(resultType, argumentType); + public static VectorMapBinaryNode create(RType resultType, RType argumentType) { + return VectorMapBinaryNodeGen.create(resultType, argumentType); } private static MapBinaryIndexedAction<? extends Object, ? extends RAbstractVector> createIndexedAction(RType resultType, RType argumentType) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinFactory.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinFactory.java index dc83aee35b5f0127bf22088a9aa1f5bf91c3fae2..a3a60bb5d2659b60ed7dc9264b9bb34be3880117 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinFactory.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinFactory.java @@ -37,8 +37,8 @@ public final class RBuiltinFactory extends RBuiltinDescriptor { private final NodeGenFactory constructor; - public RBuiltinFactory(String name, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller, NodeGenFactory constructor) { - super(name, aliases, kind, signature, nonEvalArgs, splitCaller); + public RBuiltinFactory(String name, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller, boolean alwaysSplit, NodeGenFactory constructor) { + super(name, aliases, kind, signature, nonEvalArgs, splitCaller, alwaysSplit); this.constructor = constructor; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinRootNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinRootNode.java index d584962420c5ad58a47f432e95325e37733a86bf..79ea3776376cadffc251ed2fd5c7e8f2b7b68865 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinRootNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/RBuiltinRootNode.java @@ -46,6 +46,11 @@ public final class RBuiltinRootNode extends RRootNode { return builtin; } + @Override + public boolean needsSplitting() { + return builtin.getBuiltin().isAlwaysSplit(); + } + public RCallNode inline(InlinedArguments args) { assert builtin.getSuppliedSignature() != null : this; return builtin.inline(args); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/CallMatcherNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/CallMatcherNode.java index 8bc81e9e944c2cca4f69593bf739b2c48cb38552..ac67bab11bd0bd34a37f52ff2476acca7bc543b0 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/CallMatcherNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/CallMatcherNode.java @@ -126,6 +126,10 @@ public abstract class CallMatcherNode extends Node { return replace(new CallMatcherGenericNode(forNextMethod, argsAreEvaluated)).execute(frame, suppliedSignature, suppliedArguments, function, s3Args); } else { CallMatcherCachedNode cachedNode = replace(specialize(suppliedSignature, suppliedArguments, function, getEncapsulatingSourceSection(), forNextMethod, argsAreEvaluated, this)); + // for splitting if necessary + if (RCallNode.needsSplitting(function)) { + cachedNode.call.cloneCallTarget(); + } return cachedNode.execute(frame, suppliedSignature, suppliedArguments, function, s3Args); } } @@ -159,7 +163,6 @@ public abstract class CallMatcherNode extends Node { this.preparePermutation = preparePermutation; this.permutation = permutation; this.next = next; - this.call = Truffle.getRuntime().createDirectCallNode(cachedCallTarget); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java index fb79dfab9f17f686ad6863f21a3210c193d8890d..fe9aa69a054468b708f67178ea9eda7954529629 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java @@ -137,6 +137,7 @@ public final class FunctionDefinitionNode extends RRootNode implements RSyntaxNo } + @Override public boolean needsSplitting() { return needsSplitting; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java index 78d44e43468b306cfb42e4493d4d474615b3960a..24dfe4630ef967efb705eddf4340c77432ea254c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java @@ -80,13 +80,13 @@ import com.oracle.truffle.r.runtime.gnur.*; * U = {@link UninitializedCallNode}: Forms the uninitialized end of the function PIC * D = {@link DispatchedCallNode}: Function fixed, no varargs * G = {@link GenericCallNode}: Function arbitrary, no varargs (generic case) - * + * * UV = {@link UninitializedCallNode} with varargs, * UVC = {@link UninitializedVarArgsCacheCallNode} with varargs, for varargs cache * DV = {@link DispatchedVarArgsCallNode}: Function fixed, with cached varargs * DGV = {@link DispatchedGenericVarArgsCallNode}: Function fixed, with arbitrary varargs (generic case) * GV = {@link GenericVarArgsCallNode}: Function arbitrary, with arbitrary varargs (generic case) - * + * * (RB = {@link RBuiltinNode}: individual functions that are builtins are represented by this node * which is not aware of caching). Due to {@link CachedCallNode} (see below) this is transparent to * the cache and just behaves like a D/DGV) @@ -99,11 +99,11 @@ import com.oracle.truffle.r.runtime.gnur.*; * non varargs, max depth: * | * D-D-D-U - * + * * no varargs, generic (if max depth is exceeded): * | * D-D-D-D-G - * + * * varargs: * | * DV-DV-UV <- function call target identity level cache @@ -111,7 +111,7 @@ import com.oracle.truffle.r.runtime.gnur.*; * DV * | * UVC <- varargs signature level cache - * + * * varargs, max varargs depth exceeded: * | * DV-DV-UV @@ -123,7 +123,7 @@ import com.oracle.truffle.r.runtime.gnur.*; * DV * | * DGV - * + * * varargs, max function depth exceeded: * | * DV-DV-DV-DV-GV @@ -324,12 +324,12 @@ public abstract class RCallNode extends RNode { return (RCallNode) parent; } - private static boolean needsSplitting(RFunction function) { + public static boolean needsSplitting(RFunction function) { RootNode root = function.getRootNode(); if (function.containsDispatch()) { return true; - } else if (root instanceof FunctionDefinitionNode) { - return ((FunctionDefinitionNode) root).needsSplitting(); + } else if (root instanceof RRootNode) { + return ((RRootNode) root).needsSplitting(); } return false; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryArithmeticNode.java new file mode 100644 index 0000000000000000000000000000000000000000..6ae59fc52e71b9dc1a261a8b10d6132a8bc22101 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryArithmeticNode.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.unary; + +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.model.*; +import com.oracle.truffle.r.runtime.ops.*; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic.Negate; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic.Plus; + +public class ScalarUnaryArithmeticNode extends ScalarUnaryNode { + + @Child private UnaryArithmetic arithmetic; + + public ScalarUnaryArithmeticNode(UnaryArithmetic arithmetic) { + this.arithmetic = arithmetic; + } + + @Override + public RAbstractVector tryFoldConstantTime(RAbstractVector operand, int operandLength) { + if (arithmetic instanceof Plus) { + return operand; + } else if (arithmetic instanceof Negate && operand instanceof RSequence) { + if (operand instanceof RIntSequence) { + int start = ((RIntSequence) operand).getStart(); + int stride = ((RIntSequence) operand).getStride(); + return RDataFactory.createIntSequence(applyInteger(start), applyInteger(stride), operandLength); + } else if (operand instanceof RDoubleSequence) { + double start = ((RDoubleSequence) operand).getStart(); + double stride = ((RDoubleSequence) operand).getStride(); + return RDataFactory.createDoubleSequence(applyDouble(start), applyDouble(stride), operandLength); + } + } + return null; + } + + @Override + public boolean mayFoldConstantTime(Class<? extends RAbstractVector> operandClass) { + if (arithmetic instanceof Plus) { + return true; + } else if (arithmetic instanceof Negate && RSequence.class.isAssignableFrom(operandClass)) { + return true; + } + return false; + } + + @Override + public final double applyDouble(double operand) { + if (operandNACheck.check(operand)) { + return RRuntime.DOUBLE_NA; + } + return arithmetic.op(operand); + } + + @Override + public final RComplex applyComplex(RComplex operand) { + if (operandNACheck.check(operand)) { + return RComplex.NA; + } + return arithmetic.op(operand.getRealPart(), operand.getImaginaryPart()); + } + + @Override + public final int applyInteger(int operand) { + if (operandNACheck.check(operand)) { + return RRuntime.INT_NA; + } + return arithmetic.op(operand); + } + +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryNode.java new file mode 100644 index 0000000000000000000000000000000000000000..6a8b711afe4176ade26d4088ad0d397038323ef3 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ScalarUnaryNode.java @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.unary; + +import com.oracle.truffle.api.nodes.*; +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.model.*; +import com.oracle.truffle.r.runtime.ops.na.*; + +/** + * Encapsulates an abstract scalar unary operation to be executed multiple times when calculating + * vectors or primitive types. + */ + +@SuppressWarnings("unused") +public abstract class ScalarUnaryNode extends Node { + + protected final NACheck operandNACheck = new NACheck(); + + public boolean mayFoldConstantTime(Class<? extends RAbstractVector> operandClass) { + return false; + } + + public RAbstractVector tryFoldConstantTime(RAbstractVector operand, int operandLength) { + return null; + } + + /** + * Enables all NA checks for the given input vectors. + */ + public final void enable(RAbstractVector operand) { + operandNACheck.enable(operand); + } + + /** + * Returns <code>true</code> if there was never a <code>null</code> value encountered when using + * this node. Make you have enabled the NA check properly using {@link #enable(RAbstractVector)} + * before relying on this method. + */ + public final boolean isComplete() { + return operandNACheck.neverSeenNA(); + } + + public byte applyLogical(byte operand) { + throw RInternalError.shouldNotReachHere(); + } + + public int applyInteger(int operand) { + throw RInternalError.shouldNotReachHere(); + } + + public double applyDouble(double operand) { + throw RInternalError.shouldNotReachHere(); + } + + public String applyCharacter(String operand) { + throw RInternalError.shouldNotReachHere(); + } + + public RComplex applyComplex(RComplex operand) { + throw RInternalError.shouldNotReachHere(); + } + +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java index c1b15784f4ef3e92f59d5fa8c57bdae0118e2203..f7e0eac708b66d905350bc7c88d472ee59675ca1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java @@ -22,159 +22,92 @@ */ package com.oracle.truffle.r.nodes.unary; +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.*; +import com.oracle.truffle.r.nodes.*; +import com.oracle.truffle.r.nodes.binary.*; import com.oracle.truffle.r.runtime.*; import com.oracle.truffle.r.runtime.RError.Message; -import com.oracle.truffle.r.runtime.data.*; -import com.oracle.truffle.r.runtime.data.closures.*; import com.oracle.truffle.r.runtime.data.model.*; import com.oracle.truffle.r.runtime.ops.*; -import com.oracle.truffle.r.runtime.ops.na.*; public abstract class UnaryArithmeticNode extends UnaryNode { - private final UnaryArithmetic arithmetic; - - private final NAProfile naProfile = NAProfile.create(); - - private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); - + protected final UnaryArithmeticFactory unary; private final Message error; public UnaryArithmeticNode(UnaryArithmeticFactory factory, Message error) { - this.arithmetic = factory.create(); + this.unary = factory; this.error = error; } - public UnaryArithmeticNode(UnaryArithmeticNode prev) { - this.arithmetic = prev.arithmetic; - this.error = prev.error; + @CreateCast("operand") + protected static RNode createBoxNode(RNode operand) { + return BoxPrimitiveNodeGen.create(operand); } public abstract Object execute(Object operand); - @Specialization - protected int doInt(int operand) { - return naProfile.isNA(operand) ? RRuntime.INT_NA : arithmetic.op(operand); - } - - @Specialization - protected double doDouble(double operand) { - return naProfile.isNA(operand) ? RRuntime.DOUBLE_NA : arithmetic.op(operand); - } - - @Specialization - protected RComplex doComplex(RComplex operand) { - return naProfile.isNA(operand) ? RRuntime.createComplexNA() : arithmetic.op(operand.getRealPart(), operand.getImaginaryPart()); + @Specialization(guards = {"cachedNode != null", "cachedNode.isSupported(operand)"}) + protected Object doCached(Object operand, @Cached("createCachedFast(operand)") VectorUnaryNode cachedNode) { + return cachedNode.apply(operand); } - @Specialization - protected int doLogical(byte operand) { - return naProfile.isNA(operand) ? RRuntime.INT_NA : arithmetic.op(operand); - } - - private void copyAttributes(RVector ret, RAbstractVector v) { - ret.copyRegAttributesFrom(v); - ret.setDimensions(v.getDimensions()); - ret.copyNamesFrom(attrProfiles, v); - } - - @Specialization(guards = "operands.isComplete()") - protected RDoubleVector doDoubleVector(RAbstractDoubleVector operands) { - double[] res = new double[operands.getLength()]; - for (int i = 0; i < operands.getLength(); i++) { - res[i] = arithmetic.op(operands.getDataAt(i)); + protected VectorUnaryNode createCachedFast(Object operand) { + if (isNumericVector(operand)) { + return createCached(unary.create(), operand); } - RDoubleVector ret = RDataFactory.createDoubleVector(res, RDataFactory.COMPLETE_VECTOR); - copyAttributes(ret, operands); - return ret; + return null; } - @Specialization(guards = "!operands.isComplete()") - protected RDoubleVector doDoubleVectorNA(RAbstractDoubleVector operands) { - double[] res = new double[operands.getLength()]; - for (int i = 0; i < operands.getLength(); i++) { - if (RRuntime.isNA(operands.getDataAt(i))) { - res[i] = RRuntime.DOUBLE_NA; - } else { - res[i] = arithmetic.op(operands.getDataAt(i)); + protected static VectorUnaryNode createCached(UnaryArithmetic arithmetic, Object operand) { + if (operand instanceof RAbstractVector) { + RAbstractVector castOperand = (RAbstractVector) operand; + RType operandType = castOperand.getRType(); + if (operandType.isNumeric()) { + RType type = RType.maxPrecedence(operandType, RType.Integer); + return new VectorUnaryNode(new ScalarUnaryArithmeticNode(arithmetic), castOperand.getClass(), type, type); } } - RDoubleVector ret = RDataFactory.createDoubleVector(res, false); - copyAttributes(ret, operands); - return ret; + return null; } - @Specialization(guards = "operands.isComplete()") - protected RComplexVector doComplexVector(RAbstractComplexVector operands) { - double[] res = new double[operands.getLength() * 2]; - for (int i = 0; i < operands.getLength(); i++) { - RComplex r = arithmetic.op(operands.getDataAt(i).getRealPart(), operands.getDataAt(i).getImaginaryPart()); - res[2 * i] = r.getRealPart(); - res[2 * i + 1] = r.getImaginaryPart(); - } - RComplexVector ret = RDataFactory.createComplexVector(res, RDataFactory.COMPLETE_VECTOR); - copyAttributes(ret, operands); - return ret; + protected static boolean isNumericVector(Object value) { + return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector; } - @Specialization(guards = "!operands.isComplete()") - protected RComplexVector doComplexVectorNA(RAbstractComplexVector operands) { - double[] res = new double[operands.getLength() * 2]; - for (int i = 0; i < operands.getLength(); i++) { - if (RRuntime.isNA(operands.getDataAt(i))) { - res[2 * i] = RRuntime.DOUBLE_NA; - res[2 * i + 1] = 0.0; - } else { - RComplex r = arithmetic.op(operands.getDataAt(i).getRealPart(), operands.getDataAt(i).getImaginaryPart()); - res[2 * i] = r.getRealPart(); - res[2 * i + 1] = r.getImaginaryPart(); - } - } - RComplexVector ret = RDataFactory.createComplexVector(res, false); - copyAttributes(ret, operands); - return ret; + @Specialization(contains = "doCached", guards = {"isNumericVector(operand)"}) + @TruffleBoundary + protected Object doGeneric(Object operand, // + @Cached("unary.create()") UnaryArithmetic arithmetic, // + @Cached("new(createCached(arithmetic, operand))") LRUCache lru) { + RAbstractVector operandVector = (RAbstractVector) operand; + return lru.get(arithmetic, operandVector).apply(operandVector); } - @Specialization(guards = "operands.isComplete()") - protected RIntVector doIntVector(RAbstractIntVector operands) { - int[] res = new int[operands.getLength()]; - for (int i = 0; i < operands.getLength(); i++) { - res[i] = arithmetic.op(operands.getDataAt(i)); - } - RIntVector ret = RDataFactory.createIntVector(res, RDataFactory.COMPLETE_VECTOR); - copyAttributes(ret, operands); - return ret; + @Fallback + protected Object invalidArgType(@SuppressWarnings("unused") Object operand) { + throw RError.error(getEncapsulatingSourceSection(), error); } - @Specialization(guards = "!operands.isComplete()") - protected RIntVector doIntVectorNA(RAbstractIntVector operands) { - int[] res = new int[operands.getLength()]; - for (int i = 0; i < operands.getLength(); i++) { - if (RRuntime.isNA(operands.getDataAt(i))) { - res[i] = RRuntime.INT_NA; - } else { - res[i] = arithmetic.op(operands.getDataAt(i)); + protected static final class LRUCache { + + private VectorUnaryNode cached; + + public VectorUnaryNode get(UnaryArithmetic arithmetic, RAbstractVector operand) { + if (!cached.isSupported(operand)) { + cached = createCached(arithmetic, operand); + cached.adoptChildren(); } + return cached; } - RIntVector ret = RDataFactory.createIntVector(res, false); - copyAttributes(ret, operands); - return ret; - } - @Specialization(guards = "operands.isComplete()") - protected RIntVector doLogicalVector(RAbstractLogicalVector operands) { - return doIntVector(RClosures.createLogicalToIntVector(operands)); - } - - @Specialization(guards = "!operands.isComplete()") - protected RIntVector doLogicalVectorNA(RAbstractLogicalVector operands) { - return doIntVectorNA(RClosures.createLogicalToIntVector(operands)); - } + public LRUCache(VectorUnaryNode cachedOperation) { + this.cached = cachedOperation; + // force adoption of the children for use in Truffle boundary -> vector might rewrite. + this.cached.adoptChildren(); + } - @Fallback - protected Object invalidArgType(@SuppressWarnings("unused") Object operand) { - throw RError.error(getEncapsulatingSourceSection(), error); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorMapUnaryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorMapUnaryNode.java new file mode 100644 index 0000000000000000000000000000000000000000..e48e14a66ecdac4cd1046c8db65585852002a804 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorMapUnaryNode.java @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.unary; + +import com.oracle.truffle.api.dsl.*; +import com.oracle.truffle.api.nodes.*; +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.model.*; + +@SuppressWarnings("unused") +public abstract class VectorMapUnaryNode extends Node { + + private static final MapIndexedAction<byte[], RAbstractLogicalVector> LOGICAL = // + (arithmetic, result, resultIndex, left, leftIndex) -> { + result[resultIndex] = arithmetic.applyLogical(left.getDataAt(leftIndex)); + }; + + private static final MapIndexedAction<int[], RAbstractIntVector> INTEGER = // + (arithmetic, result, resultIndex, left, leftIndex) -> { + result[resultIndex] = arithmetic.applyInteger(left.getDataAt(leftIndex)); + }; + + private static final MapIndexedAction<double[], RAbstractDoubleVector> DOUBLE = // + (arithmetic, result, resultIndex, left, leftIndex) -> { + result[resultIndex] = arithmetic.applyDouble(left.getDataAt(leftIndex)); + }; + + private static final MapIndexedAction<double[], RAbstractComplexVector> COMPLEX = // + (arithmetic, result, resultIndex, left, leftIndex) -> { + RComplex value = arithmetic.applyComplex(left.getDataAt(leftIndex)); + result[resultIndex << 1] = value.getRealPart(); + result[(resultIndex << 1) + 1] = value.getImaginaryPart(); + }; + private static final MapIndexedAction<String[], RAbstractStringVector> CHARACTER = // + (arithmetic, result, resultIndex, left, leftIndex) -> { + result[resultIndex] = arithmetic.applyCharacter(left.getDataAt(leftIndex)); + }; + + private final MapIndexedAction<Object, RAbstractVector> indexedAction; + private final RType argumentType; + private final RType resultType; + + @SuppressWarnings("unchecked") + protected VectorMapUnaryNode(RType resultType, RType argumentType) { + this.indexedAction = (MapIndexedAction<Object, RAbstractVector>) createIndexedAction(resultType, argumentType); + this.argumentType = argumentType; + this.resultType = resultType; + } + + public RType getArgumentType() { + return argumentType; + } + + public RType getResultType() { + return resultType; + } + + public static VectorMapUnaryNode create(RType resultType, RType argumentType) { + return VectorMapUnaryNodeGen.create(resultType, argumentType); + } + + private static MapIndexedAction<? extends Object, ? extends RAbstractVector> createIndexedAction(RType resultType, RType argumentType) { + switch (argumentType) { + case Logical: + return LOGICAL; + case Integer: + switch (resultType) { + case Integer: + return INTEGER; + case Double: + return DOUBLE; + default: + throw RInternalError.shouldNotReachHere(); + } + case Double: + return DOUBLE; + case Complex: + return COMPLEX; + case Character: + return CHARACTER; + default: + throw RInternalError.shouldNotReachHere(); + } + } + + public final void apply(ScalarUnaryNode scalarAction, Object store, RAbstractVector operand, int operandLength) { + assert operand.getLength() == operandLength; + assert operand.getRType() == argumentType; + assert isStoreCompatible(store, resultType, operandLength); + + executeInternal(scalarAction, store, operand, operandLength); + } + + protected static boolean isStoreCompatible(Object store, RType resultType, int operandLength) { + switch (resultType) { + case Logical: + assert store instanceof byte[] && ((byte[]) store).length == operandLength; + return true; + case Integer: + assert store instanceof int[] && ((int[]) store).length == operandLength; + return true; + case Double: + assert store instanceof double[] && ((double[]) store).length == operandLength; + return true; + case Complex: + assert store instanceof double[] && ((double[]) store).length >> 1 == operandLength; + return true; + case Character: + assert store instanceof String[] && ((String[]) store).length == operandLength; + return true; + default: + throw RInternalError.shouldNotReachHere(); + } + } + + protected abstract void executeInternal(ScalarUnaryNode node, Object store, RAbstractVector operand, int operandLength); + + @Specialization(guards = {"operandLength == 1"}) + protected void doScalar(ScalarUnaryNode node, Object store, RAbstractVector operand, int operandLength) { + indexedAction.perform(node, store, 0, operand, 0); + } + + @Specialization(contains = "doScalar") + protected void doScalarVector(ScalarUnaryNode node, Object store, RAbstractVector operand, int operandLength) { + for (int i = 0; i < operandLength; ++i) { + indexedAction.perform(node, store, i, operand, i); + } + } + + private interface MapIndexedAction<A, V extends RAbstractVector> { + + void perform(ScalarUnaryNode action, A store, int resultIndex, V operand, int operandIndex); + + } + +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorUnaryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorUnaryNode.java new file mode 100644 index 0000000000000000000000000000000000000000..2ac8822dda333cae8f8e8f5943cfa886dea781ac --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/VectorUnaryNode.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.unary; + +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.nodes.*; +import com.oracle.truffle.api.utilities.*; +import com.oracle.truffle.r.nodes.*; +import com.oracle.truffle.r.nodes.profile.*; +import com.oracle.truffle.r.runtime.*; +import com.oracle.truffle.r.runtime.data.*; +import com.oracle.truffle.r.runtime.data.model.*; + +class VectorUnaryNode extends Node { + + @Child private ScalarUnaryNode scalarNode; + @Child private VectorMapUnaryNode vectorNode; + + // profiles + private final Class<? extends RAbstractVector> operandClass; + private final VectorLengthProfile operandLengthProfile = VectorLengthProfile.create(); + private final BranchProfile hasAttributesProfile; + private final RAttributeProfiles attrProfiles; + private final ConditionProfile shareOperand; + + // compile-time optimization flags + private final boolean scalarType; + private final boolean mayContainMetadata; + private final boolean mayFoldConstantTime; + private final boolean mayShareOperand; + + VectorUnaryNode(ScalarUnaryNode scalarNode, Class<? extends RAbstractVector> operandClass, RType argumentType, RType resultType) { + this.scalarNode = scalarNode; + this.vectorNode = VectorMapUnaryNode.create(resultType, argumentType); + this.operandClass = operandClass; + this.scalarType = RScalarVector.class.isAssignableFrom(operandClass); + boolean operandVector = RVector.class.isAssignableFrom(operandClass); + this.mayContainMetadata = operandVector; + this.mayFoldConstantTime = scalarNode.mayFoldConstantTime(operandClass); + this.mayShareOperand = operandVector; + + // lazily create profiles only if needed to avoid unnecessary allocations + this.shareOperand = operandVector ? ConditionProfile.createBinaryProfile() : null; + this.attrProfiles = mayContainMetadata ? RAttributeProfiles.create() : null; + this.hasAttributesProfile = mayContainMetadata ? BranchProfile.create() : null; + } + + public RType getArgumentType() { + return vectorNode.getArgumentType(); + } + + public RType getResultType() { + return vectorNode.getResultType(); + } + + public boolean isSupported(Object operand) { + return operand.getClass() == operandClass; + } + + public Object apply(Object originalOperand) { + assert isSupported(originalOperand); + RAbstractVector operand = operandClass.cast(originalOperand); + + int operandLength = operandLengthProfile.profile(operand.getLength()); + RAbstractVector operandCast = operand.castSafe(getArgumentType()); + + scalarNode.enable(operandCast); + if (scalarType) { + assert operand.getLength() == 1; + return scalarOperation(operandCast); + } else { + return vectorOperation(operand, operandCast, operandLength); + } + } + + private Object scalarOperation(RAbstractVector operand) { + switch (getArgumentType()) { + case Logical: + return scalarNode.applyLogical(((RAbstractLogicalVector) operand).getDataAt(0)); + case Integer: + return scalarNode.applyInteger(((RAbstractIntVector) operand).getDataAt(0)); + case Double: + return scalarNode.applyDouble(((RAbstractDoubleVector) operand).getDataAt(0)); + case Complex: + return scalarNode.applyComplex(((RAbstractComplexVector) operand).getDataAt(0)); + default: + throw RInternalError.shouldNotReachHere(); + } + } + + private Object vectorOperation(RAbstractVector operand, RAbstractVector operandCast, int operandLength) { + RAbstractVector target = null; + if (mayFoldConstantTime) { + target = scalarNode.tryFoldConstantTime(operandCast, operandLength); + } + if (target == null) { + target = createOrShareVector(operandLength, operand); + Object store; + if (target instanceof RAccessibleStore) { + store = ((RAccessibleStore<?>) target).getInternalStore(); + } else { + throw RInternalError.shouldNotReachHere(); + } + vectorNode.apply(scalarNode, store, operandCast, operandLength); + RNode.reportWork(this, operandLength); + } + if (mayContainMetadata) { + target = handleMetadata(target, operand); + } + target.setComplete(scalarNode.isComplete()); + return target; + } + + private RAbstractVector createOrShareVector(int operandLength, RAbstractVector operand) { + RType resultType = getResultType(); + if (mayShareOperand && operand.getRType() == resultType && shareOperand.profile(((RShareable) operand).isTemporary())) { + return operand; + } + return resultType.create(operandLength); + } + + private RAbstractVector handleMetadata(RAbstractVector target, RAbstractVector operand) { + RAbstractVector result = target; + if (containsMetadata(operand) && operand != target) { + hasAttributesProfile.enter(); + result = result.materialize(); + copyAttributesInternal((RVector) result, operand); + } + return result; + } + + private boolean containsMetadata(RAbstractVector vector) { + return vector instanceof RVector && (vector.hasDimensions() || vector.getAttributes() != null || vector.getNames(attrProfiles) != null || vector.getDimNames(attrProfiles) != null); + } + + @TruffleBoundary + private void copyAttributesInternal(RVector result, RAbstractVector attributeSource) { + result.copyRegAttributesFrom(attributeSource); + result.setDimensions(attributeSource.getDimensions()); + result.copyNamesFrom(attrProfiles, attributeSource); + } +} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RBuiltin.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RBuiltin.java index 6b4c5f868c93457cd269a54b1db0409173f028e5..32dc331cbfe8fa52c6123df52035da74f18a556a 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RBuiltin.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RBuiltin.java @@ -63,4 +63,6 @@ public @interface RBuiltin { * sites. <code>name</code> indicates the builtin name defined in {@link #name()}. */ boolean splitCaller() default false; + + boolean alwaysSplit() default false; } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RType.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RType.java index e0611fc44dc76a45fc6c7ec986be438e3ff4d972..450cab23c083a9b8814e88e633ac5964f55e3434 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RType.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RType.java @@ -133,4 +133,21 @@ public enum RType { } } + public RAbstractVector create(int length) { + switch (this) { + case Logical: + return RDataFactory.createLogicalVector(length); + case Integer: + return RDataFactory.createIntVector(length); + case Double: + return RDataFactory.createDoubleVector(length); + case Complex: + return RDataFactory.createComplexVector(length); + case Character: + return RDataFactory.createStringVector(length); + default: + throw RInternalError.shouldNotReachHere(); + } + } + } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RBuiltinDescriptor.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RBuiltinDescriptor.java index 7439b0d90b57e34dd837e19e4907d1ffbaa3952d..82cf50b315427213399de66bc2cae35a8026ba66 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RBuiltinDescriptor.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RBuiltinDescriptor.java @@ -34,16 +34,18 @@ public abstract class RBuiltinDescriptor { private final ArgumentsSignature signature; private final int[] nonEvalArgs; private final boolean splitCaller; + private final boolean alwaysSplit; private final RGroupGenerics group; @CompilationFinal private final boolean[] evaluatesArgument; - public RBuiltinDescriptor(String name, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller) { + public RBuiltinDescriptor(String name, String[] aliases, RBuiltinKind kind, ArgumentsSignature signature, int[] nonEvalArgs, boolean splitCaller, boolean alwaysSplit) { this.name = name; this.aliases = aliases; this.kind = kind; this.signature = signature; this.nonEvalArgs = nonEvalArgs; this.splitCaller = splitCaller; + this.alwaysSplit = alwaysSplit; this.group = RGroupGenerics.getGroup(name); evaluatesArgument = new boolean[signature.getLength()]; @@ -74,6 +76,10 @@ public abstract class RBuiltinDescriptor { return nonEvalArgs; } + public boolean isAlwaysSplit() { + return alwaysSplit; + } + public boolean isSplitCaller() { return splitCaller; } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java index 23931f3a316ff535878f1202579951dace102a9d..9b750b36ae1455267ba56f2f763999a5cf17ede7 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java @@ -28,31 +28,31 @@ public abstract class BinaryArithmetic extends Operation { /* Fake RBuiltins to unify the binary operations */ - @RBuiltin(name = "+", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "+", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class AddBuiltin { } - @RBuiltin(name = "-", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "-", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class SubtractBuiltin { } - @RBuiltin(name = "/", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "/", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class DivBuiltin { } - @RBuiltin(name = "%/%", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "%/%", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class IntegerDivBuiltin { } - @RBuiltin(name = "%%", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "%%", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class ModBuiltin { } - @RBuiltin(name = "*", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "*", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class MultiplyBuiltin { } - @RBuiltin(name = "^", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "^", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class PowBuiltin { } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryCompare.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryCompare.java index 51cdb19edf48abf7f758ab631bea9fc27e096362..30fa1f6682759a51232440a6a0fdb199a203fd77 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryCompare.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryCompare.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2015, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,27 +28,27 @@ import com.oracle.truffle.r.runtime.data.*; public abstract class BinaryCompare extends BooleanOperation { /* Fake RBuiltins to unify the compare operations */ - @RBuiltin(name = "==", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "==", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class EqualBuiltin { } - @RBuiltin(name = "!=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "!=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class NotEqualBuiltin { } - @RBuiltin(name = ">=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = ">=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class GreaterEqualBuiltin { } - @RBuiltin(name = ">", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = ">", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class GreaterBuiltin { } - @RBuiltin(name = "<=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "<=", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class LessEqualBuiltin { } - @RBuiltin(name = "<", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}) + @RBuiltin(name = "<", kind = RBuiltinKind.PRIMITIVE, parameterNames = {"", ""}, alwaysSplit = true) public static class LessBuiltin { } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java index 055c521d06c41854d6f1c055212d9bdf38d378c4..ec1747be3a1a5dc0af5cafb5897b2fe19e95fcec 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java @@ -7,7 +7,7 @@ * Copyright (c) 1998, Ross Ihaka * Copyright (c) 1998-2012, The R Core Team * Copyright (c) 2005, The R Foundation - * Copyright (c) 2013, 2014, Oracle and/or its affiliates + * Copyright (c) 2013, 2015, Oracle and/or its affiliates * * All rights reserved. */ @@ -25,6 +25,7 @@ public abstract class UnaryArithmetic extends Operation { public static final UnaryArithmeticFactory FLOOR = Floor::new; public static final UnaryArithmeticFactory CEILING = Ceiling::new; public static final UnaryArithmeticFactory PLUS = Plus::new; + public static final UnaryArithmeticFactory[] ALL = new UnaryArithmeticFactory[]{NEGATE, ROUND, FLOOR, CEILING, PLUS}; public UnaryArithmetic() { super(false, false);