diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java index ec058295772e1fd4df8aa0c514bcf36fe3c7086f..407ce9faac3f31777368b02558b5ce92eb6a1b82 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java @@ -64,24 +64,27 @@ public abstract class Round extends RBuiltinNode.Arg2 { casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())); // TODO: this should also accept vectors - // TODO: digits argument is rounded, not simply stripped off the decimal part - casts.arg("digits").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())).asIntegerVector().findFirst(); + casts.arg("digits").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())).asDoubleVector().findFirst(); + } + + protected static boolean isZero(double value) { + return value == 0; } @Specialization - protected double round(int x, @SuppressWarnings("unused") int digits) { + protected double round(int x, @SuppressWarnings("unused") double digits) { check.enable(x); return check.check(x) ? RRuntime.DOUBLE_NA : x; } @Specialization - protected double round(byte x, @SuppressWarnings("unused") int digits) { + protected double round(byte x, @SuppressWarnings("unused") double digits) { check.enable(x); return check.check(x) ? RRuntime.DOUBLE_NA : x; } @Specialization - protected RDoubleVector round(RAbstractLogicalVector x, @SuppressWarnings("unused") int digits) { + protected RDoubleVector round(RAbstractLogicalVector x, @SuppressWarnings("unused") double digits) { double[] data = new double[x.getLength()]; check.enable(x); for (int i = 0; i < data.length; i++) { @@ -92,7 +95,7 @@ public abstract class Round extends RBuiltinNode.Arg2 { } @Specialization - protected RDoubleVector round(RAbstractIntVector x, @SuppressWarnings("unused") int digits) { + protected RDoubleVector round(RAbstractIntVector x, @SuppressWarnings("unused") double digits) { double[] data = new double[x.getLength()]; check.enable(x); for (int i = 0; i < data.length; i++) { @@ -102,20 +105,24 @@ public abstract class Round extends RBuiltinNode.Arg2 { return RDataFactory.createDoubleVector(data, check.neverSeenNA()); } - @Specialization(guards = "digits == 0") - protected double round(double x, @SuppressWarnings("unused") int digits) { + @Specialization(guards = "isZero(digits)") + protected double round(double x, @SuppressWarnings("unused") double digits) { check.enable(x); return check.check(x) ? RRuntime.DOUBLE_NA : roundOp.op(x); } - @Specialization(guards = "digits != 0") protected double roundDigits(double x, int digits) { check.enable(x); return check.check(x) ? RRuntime.DOUBLE_NA : roundOp.opd(x, digits); } - @Specialization(guards = "digits == 0") - protected RDoubleVector round(RAbstractDoubleVector x, int digits) { + @Specialization(guards = "!isZero(digits)") + protected double roundDigits(double x, double digits) { + return roundDigits(x, (int) Math.round(digits)); + } + + @Specialization(guards = "isZero(digits)") + protected RDoubleVector round(RAbstractDoubleVector x, double digits) { double[] result = new double[x.getLength()]; check.enable(x); for (int i = 0; i < x.getLength(); i++) { @@ -127,9 +134,10 @@ public abstract class Round extends RBuiltinNode.Arg2 { return ret; } - @Specialization(guards = "digits != 0") - protected RDoubleVector roundDigits(RAbstractDoubleVector x, int digits) { + @Specialization(guards = "!isZero(dDigits)") + protected RDoubleVector roundDigits(RAbstractDoubleVector x, double dDigits) { double[] result = new double[x.getLength()]; + int digits = (int) Math.round(dDigits); check.enable(x); for (int i = 0; i < x.getLength(); i++) { double value = x.getDataAt(i); @@ -140,20 +148,24 @@ public abstract class Round extends RBuiltinNode.Arg2 { return ret; } - @Specialization(guards = "digits == 0") - protected RComplex round(RComplex x, @SuppressWarnings("unused") int digits) { + @Specialization(guards = "isZero(digits)") + protected RComplex round(RComplex x, @SuppressWarnings("unused") double digits) { check.enable(x); return check.check(x) ? RComplex.createNA() : RComplex.valueOf(roundOp.op(x.getRealPart()), roundOp.op(x.getImaginaryPart())); } - @Specialization(guards = "digits != 0") protected RComplex roundDigits(RComplex x, int digits) { check.enable(x); return check.check(x) ? RComplex.createNA() : roundOp.opd(x.getRealPart(), x.getImaginaryPart(), digits); } - @Specialization(guards = "digits == 0") - protected RComplexVector round(RAbstractComplexVector x, int digits) { + @Specialization(guards = "!isZero(digits)") + protected RComplex roundDigits(RComplex x, double digits) { + return roundDigits(x, (int) Math.round(digits)); + } + + @Specialization(guards = "isZero(digits)") + protected RComplexVector round(RAbstractComplexVector x, double digits) { double[] result = new double[x.getLength() << 1]; check.enable(x); for (int i = 0; i < x.getLength(); i++) { @@ -168,8 +180,9 @@ public abstract class Round extends RBuiltinNode.Arg2 { return ret; } - @Specialization(guards = "digits != 0") - protected RComplexVector roundDigits(RAbstractComplexVector x, int digits) { + @Specialization(guards = "!isZero(dDigits)") + protected RComplexVector roundDigits(RAbstractComplexVector x, double dDigits) { + int digits = (int) Math.round(dDigits); double[] result = new double[x.getLength() << 1]; check.enable(x); for (int i = 0; i < x.getLength(); i++) { diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java index 1101dc9b129c1b46fa0d40ad1ca5fa7bc23745c4..e5981201e3b595f02f0d3b8d867d8e62a2568657 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java @@ -48,9 +48,12 @@ public class TestBuiltin_round extends TestBase { assertEval("{ round(c(0,0.2,NaN,0.6,NA,1)) }"); assertEval("{ round(as.complex(c(0,0.2,NaN,0.6,NA,1))) }"); - // FIXME: we need to decide whether 2.8 means three digits (GnuR) or two (FastR) when - // calling round() - assertEval(Ignored.ImplementationError, "{ round(1.123456,digit=2.8) }"); + assertEval("{ round(1.123456,digit=2.8) }"); + assertEval("{ round(1.123456,digit=2.5) }"); + assertEval("{ round(1.123456,digit=2.3) }"); + assertEval("{ round(12344.126,digit=-2.8) }"); + assertEval("{ round(12344.123456,digit=-2.5) }"); + assertEval("{ round(12344.123456,digit=-2.3) }"); assertEval("{ typeof(round(42L)); }"); assertEval("{ typeof(round(TRUE)); }");