diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java index 628a8447547c09ec2af3c49711f97ed8d1a3f769..a286bc76d9c332926132542a321e2121b11a2ea7 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java @@ -28,9 +28,13 @@ import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; +import java.util.Arrays; + import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; @@ -49,14 +53,10 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.na.NACheck; import com.oracle.truffle.r.runtime.ops.na.NAProfile; -import java.util.Arrays; -import java.util.function.Function; public class LogFunctions { @RBuiltin(name = "log", kind = PRIMITIVE, parameterNames = {"x", "base"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) @@ -158,103 +158,47 @@ public class LogFunctions { return logb(x, base, divNode, naBase); } - @Specialization - protected RDoubleVector log(RAbstractIntVector vector, double base, + @Specialization(guards = "!isRAbstractComplexVector(vector)") + protected RDoubleVector log(RAbstractVector vector, double base, + @Cached("createClassProfile()") ValueProfile vectorProfile, + @Cached("createBinaryProfile()") ConditionProfile isNAProfile, @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, @Cached("create()") GetNamesAttributeNode getNamesNode, @Cached("create()") GetDimAttributeNode getDimsNode, @Cached("create()") NACheck xNACheck, @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> xNACheck.convertIntToDouble(vector.getDataAt(index)), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck); - } - - @Specialization - protected RDoubleVector log(RAbstractDoubleVector vector, double base, - @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, - @Cached("create()") GetNamesAttributeNode getNamesNode, - @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("create()") NACheck xNACheck, - @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> checkDouble(vector.getDataAt(index), xNACheck), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck); - } - - private static double checkDouble(double d, NACheck na) { - na.check(d); - return d; - } - - @Specialization - protected RDoubleVector log(RAbstractLogicalVector vector, double base, - @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, - @Cached("create()") GetNamesAttributeNode getNamesNode, - @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("create()") NACheck xNACheck, - @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> xNACheck.convertLogicalToDouble(vector.getDataAt(index)), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck); + RAbstractDoubleVector doubleVector = (RAbstractDoubleVector) vectorProfile.profile(vector).castSafe(RType.Double, isNAProfile); + return logInternal(doubleVector, base, copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck); } @Specialization protected RComplexVector log(RAbstractComplexVector vector, double base, + @Cached("createClassProfile()") ValueProfile vectorProfile, @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, @Cached("create()") GetNamesAttributeNode getNamesNode, @Cached("create()") GetDimAttributeNode getDimsNode, @Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode, @Cached("create()") NACheck xNACheck, @Cached("create()") NACheck baseNACheck) { - return log(vector, RComplex.valueOf(base, 0), copyAttrsNode, getNamesNode, getDimsNode, divNode, xNACheck, baseNACheck); + return logInternal(vectorProfile.profile(vector), RComplex.valueOf(base, 0), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); } @Specialization - protected RAbstractComplexVector log(RAbstractIntVector vector, RComplex base, + protected RAbstractComplexVector log(RAbstractVector vector, RComplex base, + @Cached("createClassProfile()") ValueProfile vectorProfile, + @Cached("createBinaryProfile()") ConditionProfile isNAProfile, @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, @Cached("create()") GetNamesAttributeNode getNamesNode, @Cached("create()") GetDimAttributeNode getDimsNode, @Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode, @Cached("create()") NACheck xNACheck, @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> xNACheck.convertIntToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); + RAbstractComplexVector complexVector = (RAbstractComplexVector) vectorProfile.profile(vector).castSafe(RType.Complex, isNAProfile); + return logInternal(complexVector, base, divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); } - @Specialization - protected RAbstractComplexVector log(RAbstractDoubleVector vector, RComplex base, - @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, - @Cached("create()") GetNamesAttributeNode getNamesNode, - @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode, - @Cached("create()") NACheck xNACheck, - @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> xNACheck.convertDoubleToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); - } - - @Specialization - protected RAbstractComplexVector log(RAbstractLogicalVector vector, RComplex base, - @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, - @Cached("create()") GetNamesAttributeNode getNamesNode, - @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode, - @Cached("create()") NACheck xNACheck, - @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> xNACheck.convertLogicalToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); - } - - @Specialization - protected RComplexVector log(RAbstractComplexVector vector, RComplex base, - @Cached("create()") CopyOfRegAttributesNode copyAttrsNode, - @Cached("create()") GetNamesAttributeNode getNamesNode, - @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode, - @Cached("create()") NACheck xNACheck, - @Cached("create()") NACheck baseNACheck) { - return log(vector, base, index -> checkComplex(vector.getDataAt(index), xNACheck), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck); - } - - private static RComplex checkComplex(RComplex rc, NACheck xNACheck) { - xNACheck.check(rc); - return rc; - } - - private RDoubleVector log(RAbstractVector vector, double base, Function<Integer, Double> toDouble, CopyOfRegAttributesNode copyAttrsNode, GetNamesAttributeNode getNamesNode, - GetDimAttributeNode getDimsNode, NACheck xNACheck, NACheck baseNACheck) { + private RDoubleVector logInternal(RAbstractDoubleVector vector, double base, CopyOfRegAttributesNode copyAttrsNode, GetNamesAttributeNode getNamesNode, GetDimAttributeNode getDimsNode, + NACheck xNACheck, NACheck baseNACheck) { baseNACheck.enable(base); double[] resultVector = new double[vector.getLength()]; if (baseNACheck.check(base)) { @@ -266,12 +210,8 @@ public class LogFunctions { xNACheck.enable(vector); Runnable[] warningResult = new Runnable[1]; for (int i = 0; i < vector.getLength(); i++) { - double value = toDouble.apply(i); - if (!naX.isNA(value)) { - resultVector[i] = logb(value, base, warningResult); - } else { - resultVector[i] = value; - } + double value = vector.getDataAt(i); + resultVector[i] = xNACheck.check(value) ? RRuntime.DOUBLE_NA : logb(value, base, warningResult); } if (warningResult[0] != null) { warningResult[0].run(); @@ -315,8 +255,8 @@ public class LogFunctions { return result; } - private RComplexVector log(RAbstractVector vector, RComplex base, Function<Integer, RComplex> toComplex, BinaryMapArithmeticFunctionNode divNode, GetDimAttributeNode getDimsNode, - GetNamesAttributeNode getNamesNode, CopyOfRegAttributesNode copyAttrsNode, NACheck xNACheck, NACheck baseNACheck) { + private RComplexVector logInternal(RAbstractComplexVector vector, RComplex base, BinaryMapArithmeticFunctionNode divNode, GetDimAttributeNode getDimsNode, GetNamesAttributeNode getNamesNode, + CopyOfRegAttributesNode copyAttrsNode, NACheck xNACheck, NACheck baseNACheck) { baseNACheck.enable(base); double[] complexVector = new double[vector.getLength() * 2]; if (baseNACheck.check(base)) { @@ -328,13 +268,13 @@ public class LogFunctions { xNACheck.enable(vector); boolean seenNaN = false; for (int i = 0; i < vector.getLength(); i++) { - RComplex value = toComplex.apply(i); - if (!naX.isNA(value)) { + RComplex value = vector.getDataAt(i); + if (xNACheck.check(value)) { + fill(complexVector, i * 2, value); + } else { RComplex rc = logb(value, base, divNode, false); seenNaN = isNaN(rc); fill(complexVector, i * 2, rc); - } else { - fill(complexVector, i * 2, value); } } if (seenNaN) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Order.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Order.java index d8aebe6acbdd7c985e3faa8279d7be3faec83863..d13a708f1927c9e7445bdd9ff30cee4539bc24bc 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Order.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Order.java @@ -67,6 +67,7 @@ public abstract class Order extends RPrecedenceBuiltinNode { private final BranchProfile error = BranchProfile.create(); private final ConditionProfile notRemoveNAs = ConditionProfile.createBinaryProfile(); + private final ValueProfile vectorProfile = ValueProfile.createClassProfile(); /** * For use by {@link RadixSort}. @@ -79,7 +80,8 @@ public abstract class Order extends RPrecedenceBuiltinNode { return executeOrderVector1(v, naLast, dec, false); } - private RIntVector executeOrderVector1(RAbstractVector v, byte naLast, boolean dec, boolean needsStringCollation) { + private RIntVector executeOrderVector1(RAbstractVector vIn, byte naLast, boolean dec, boolean needsStringCollation) { + RAbstractVector v = vectorProfile.profile(vIn); int n = v.getLength(); reportWork(n); @@ -170,31 +172,27 @@ public abstract class Order extends RPrecedenceBuiltinNode { @Specialization(guards = {"oneVec(args)", "isFirstIntegerPrecedence(args)"}) Object orderInt(byte naLast, boolean decreasing, RArgsValuesAndNames args) { - Object[] vectors = args.getArguments(); - RAbstractIntVector v = (RAbstractIntVector) castVector(vectors[0]); + RAbstractIntVector v = (RAbstractIntVector) castVector(args.getArgument(0)); return executeOrderVector1(v, naLast, decreasing); } @Specialization(guards = {"oneVec(args)", "isFirstDoublePrecedence(args)"}) Object orderDouble(byte naLast, boolean decreasing, RArgsValuesAndNames args) { - Object[] vectors = args.getArguments(); - RAbstractDoubleVector v = (RAbstractDoubleVector) castVector(vectors[0]); + RAbstractDoubleVector v = (RAbstractDoubleVector) castVector(args.getArgument(0)); return executeOrderVector1(v, naLast, decreasing); } @Specialization(guards = {"oneVec(args)", "isFirstLogicalPrecedence(args)"}) Object orderLogical(byte naLast, boolean decreasing, RArgsValuesAndNames args, @Cached("createBinaryProfile()") ConditionProfile isNAProfile) { - Object[] vectors = args.getArguments(); - RAbstractIntVector v = (RAbstractIntVector) castVector(vectors[0]).castSafe(RType.Integer, isNAProfile); + RAbstractIntVector v = (RAbstractIntVector) castVector(args.getArgument(0)).castSafe(RType.Integer, isNAProfile); return executeOrderVector1(v, naLast, decreasing); } @Specialization(guards = {"oneVec(args)", "isFirstStringPrecedence(args)"}) Object orderString(byte naLast, boolean decreasing, RArgsValuesAndNames args, @Cached("create()") BranchProfile collationProfile) { - Object[] vectors = args.getArguments(); - RAbstractStringVector v = (RAbstractStringVector) castVector(vectors[0]); + RAbstractStringVector v = (RAbstractStringVector) castVector(args.getArgument(0)); int n = v.getLength(); boolean needsCollation = false; outer: for (int i = 0; i < n; i++) { @@ -214,8 +212,7 @@ public abstract class Order extends RPrecedenceBuiltinNode { @Specialization(guards = {"oneVec(args)", "isFirstComplexPrecedence( args)"}) Object orderComplex(byte naLast, boolean decreasing, RArgsValuesAndNames args) { - Object[] vectors = args.getArguments(); - RAbstractComplexVector v = (RAbstractComplexVector) castVector(vectors[0]); + RAbstractComplexVector v = (RAbstractComplexVector) castVector(args.getArgument(0)); return executeOrderVector1(v, naLast, decreasing); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java index 3ef308f8b141ed5275f074ce3c285149df09ec25..d807a19e440953061b3c00bf501179c845f61908 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java @@ -93,8 +93,11 @@ public abstract class Unique extends RBuiltinNode.Arg4 { } @SuppressWarnings("unused") - @Specialization - protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) { + @Specialization(guards = "vecIn.getClass() == vecClass") + protected RStringVector doUniqueCachedString(RAbstractStringVector vecIn, byte incomparables, byte fromLast, int nmax, + @Cached("vecIn.getClass()") Class<? extends RAbstractStringVector> vecClass) { + RAbstractStringVector vec = vecClass.cast(vecIn); + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<String> set = new NonRecursiveHashSet<>(vec.getLength()); String[] data = new String[vec.getLength()]; @@ -120,6 +123,11 @@ public abstract class Unique extends RBuiltinNode.Arg4 { } } + @Specialization(replaces = "doUniqueCachedString") + protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) { + return doUniqueCachedString(vec, incomparables, fromLast, nmax, RAbstractStringVector.class); + } + // these are intended to stay private as they will go away once we figure out which external // library to use @@ -245,6 +253,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { protected RIntVector doUniqueCached(RAbstractIntVector vecIn, byte incomparables, byte fromLast, int nmax, @Cached("vecIn.getClass()") Class<? extends RAbstractIntVector> vecClass) { RAbstractIntVector vec = vecClass.cast(vecIn); + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSetInt set = new NonRecursiveHashSetInt(); int[] data = new int[16]; @@ -296,6 +305,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { @Specialization(guards = "!lengthOne(list)") @TruffleBoundary protected RList doUnique(RList list, byte incomparables, byte fromLast, int nmax) { + reportWork(list.getLength()); /* * Brute force, as manual says: Using this for lists is potentially slow, especially if the * elements are not atomic vectors (see vector) or differ only in their attributes. In the @@ -375,6 +385,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { @SuppressWarnings("unused") @Specialization protected RDoubleVector doUnique(RAbstractDoubleVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSetDouble set = new NonRecursiveHashSetDouble(vec.getLength()); double[] data = new double[vec.getLength()]; @@ -401,6 +412,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { @SuppressWarnings("unused") @Specialization protected RLogicalVector doUnique(RAbstractLogicalVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); ByteArray dataList = new ByteArray(vec.getLength()); for (int i = 0; i < vec.getLength(); i++) { byte val = vec.getDataAt(i); @@ -414,6 +426,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { @SuppressWarnings("unused") @Specialization protected RComplexVector doUnique(RAbstractComplexVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<RComplex> set = new NonRecursiveHashSet<>(vec.getLength()); double[] data = new double[vec.getLength() * 2]; @@ -441,6 +454,7 @@ public abstract class Unique extends RBuiltinNode.Arg4 { @SuppressWarnings("unused") @Specialization protected RRawVector doUnique(RAbstractRawVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<RRaw> set = new NonRecursiveHashSet<>(vec.getLength()); byte[] data = new byte[vec.getLength()]; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastStringNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastStringNode.java index a59b45866a9f896aba4c630d879592f5af370bc3..66cc1de16106911653b8e49691134ee0778d02fb 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastStringNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastStringNode.java @@ -25,6 +25,7 @@ package com.oracle.truffle.r.nodes.unary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.runtime.RDeparse; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RLanguage; @@ -65,8 +66,10 @@ public abstract class CastStringNode extends CastStringBaseNode { } @Specialization - protected RStringVector doAbstractContainer(RAbstractContainer operand, + protected RStringVector doAbstractContainer(RAbstractContainer operandIn, + @Cached("createClassProfile()") ValueProfile operandProfile, @Cached("createBinaryProfile()") ConditionProfile isLanguageProfile) { + RAbstractContainer operand = operandProfile.profile(operandIn); String[] sdata = new String[operand.getLength()]; // conversions to character will not introduce new NAs for (int i = 0; i < operand.getLength(); i++) {