From 5bf80e3eca095b918f6e8976217ff35d4c19a272 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Thu, 12 Oct 2017 18:18:26 +0200
Subject: [PATCH] handle non-integer digit numbers in "round"

---
 .../truffle/r/nodes/builtin/base/Round.java   | 53 ++++++++++++-------
 .../r/test/builtins/TestBuiltin_round.java    |  9 ++--
 2 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java
index ec05829577..407ce9faac 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java
@@ -64,24 +64,27 @@ public abstract class Round extends RBuiltinNode.Arg2 {
         casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue()));
 
         // TODO: this should also accept vectors
-        // TODO: digits argument is rounded, not simply stripped off the decimal part
-        casts.arg("digits").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())).asIntegerVector().findFirst();
+        casts.arg("digits").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())).asDoubleVector().findFirst();
+    }
+
+    protected static boolean isZero(double value) {
+        return value == 0;
     }
 
     @Specialization
-    protected double round(int x, @SuppressWarnings("unused") int digits) {
+    protected double round(int x, @SuppressWarnings("unused") double digits) {
         check.enable(x);
         return check.check(x) ? RRuntime.DOUBLE_NA : x;
     }
 
     @Specialization
-    protected double round(byte x, @SuppressWarnings("unused") int digits) {
+    protected double round(byte x, @SuppressWarnings("unused") double digits) {
         check.enable(x);
         return check.check(x) ? RRuntime.DOUBLE_NA : x;
     }
 
     @Specialization
-    protected RDoubleVector round(RAbstractLogicalVector x, @SuppressWarnings("unused") int digits) {
+    protected RDoubleVector round(RAbstractLogicalVector x, @SuppressWarnings("unused") double digits) {
         double[] data = new double[x.getLength()];
         check.enable(x);
         for (int i = 0; i < data.length; i++) {
@@ -92,7 +95,7 @@ public abstract class Round extends RBuiltinNode.Arg2 {
     }
 
     @Specialization
-    protected RDoubleVector round(RAbstractIntVector x, @SuppressWarnings("unused") int digits) {
+    protected RDoubleVector round(RAbstractIntVector x, @SuppressWarnings("unused") double digits) {
         double[] data = new double[x.getLength()];
         check.enable(x);
         for (int i = 0; i < data.length; i++) {
@@ -102,20 +105,24 @@ public abstract class Round extends RBuiltinNode.Arg2 {
         return RDataFactory.createDoubleVector(data, check.neverSeenNA());
     }
 
-    @Specialization(guards = "digits == 0")
-    protected double round(double x, @SuppressWarnings("unused") int digits) {
+    @Specialization(guards = "isZero(digits)")
+    protected double round(double x, @SuppressWarnings("unused") double digits) {
         check.enable(x);
         return check.check(x) ? RRuntime.DOUBLE_NA : roundOp.op(x);
     }
 
-    @Specialization(guards = "digits != 0")
     protected double roundDigits(double x, int digits) {
         check.enable(x);
         return check.check(x) ? RRuntime.DOUBLE_NA : roundOp.opd(x, digits);
     }
 
-    @Specialization(guards = "digits == 0")
-    protected RDoubleVector round(RAbstractDoubleVector x, int digits) {
+    @Specialization(guards = "!isZero(digits)")
+    protected double roundDigits(double x, double digits) {
+        return roundDigits(x, (int) Math.round(digits));
+    }
+
+    @Specialization(guards = "isZero(digits)")
+    protected RDoubleVector round(RAbstractDoubleVector x, double digits) {
         double[] result = new double[x.getLength()];
         check.enable(x);
         for (int i = 0; i < x.getLength(); i++) {
@@ -127,9 +134,10 @@ public abstract class Round extends RBuiltinNode.Arg2 {
         return ret;
     }
 
-    @Specialization(guards = "digits != 0")
-    protected RDoubleVector roundDigits(RAbstractDoubleVector x, int digits) {
+    @Specialization(guards = "!isZero(dDigits)")
+    protected RDoubleVector roundDigits(RAbstractDoubleVector x, double dDigits) {
         double[] result = new double[x.getLength()];
+        int digits = (int) Math.round(dDigits);
         check.enable(x);
         for (int i = 0; i < x.getLength(); i++) {
             double value = x.getDataAt(i);
@@ -140,20 +148,24 @@ public abstract class Round extends RBuiltinNode.Arg2 {
         return ret;
     }
 
-    @Specialization(guards = "digits == 0")
-    protected RComplex round(RComplex x, @SuppressWarnings("unused") int digits) {
+    @Specialization(guards = "isZero(digits)")
+    protected RComplex round(RComplex x, @SuppressWarnings("unused") double digits) {
         check.enable(x);
         return check.check(x) ? RComplex.createNA() : RComplex.valueOf(roundOp.op(x.getRealPart()), roundOp.op(x.getImaginaryPart()));
     }
 
-    @Specialization(guards = "digits != 0")
     protected RComplex roundDigits(RComplex x, int digits) {
         check.enable(x);
         return check.check(x) ? RComplex.createNA() : roundOp.opd(x.getRealPart(), x.getImaginaryPart(), digits);
     }
 
-    @Specialization(guards = "digits == 0")
-    protected RComplexVector round(RAbstractComplexVector x, int digits) {
+    @Specialization(guards = "!isZero(digits)")
+    protected RComplex roundDigits(RComplex x, double digits) {
+        return roundDigits(x, (int) Math.round(digits));
+    }
+
+    @Specialization(guards = "isZero(digits)")
+    protected RComplexVector round(RAbstractComplexVector x, double digits) {
         double[] result = new double[x.getLength() << 1];
         check.enable(x);
         for (int i = 0; i < x.getLength(); i++) {
@@ -168,8 +180,9 @@ public abstract class Round extends RBuiltinNode.Arg2 {
         return ret;
     }
 
-    @Specialization(guards = "digits != 0")
-    protected RComplexVector roundDigits(RAbstractComplexVector x, int digits) {
+    @Specialization(guards = "!isZero(dDigits)")
+    protected RComplexVector roundDigits(RAbstractComplexVector x, double dDigits) {
+        int digits = (int) Math.round(dDigits);
         double[] result = new double[x.getLength() << 1];
         check.enable(x);
         for (int i = 0; i < x.getLength(); i++) {
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java
index 1101dc9b12..e5981201e3 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_round.java
@@ -48,9 +48,12 @@ public class TestBuiltin_round extends TestBase {
         assertEval("{ round(c(0,0.2,NaN,0.6,NA,1)) }");
         assertEval("{ round(as.complex(c(0,0.2,NaN,0.6,NA,1))) }");
 
-        // FIXME: we need to decide whether 2.8 means three digits (GnuR) or two (FastR) when
-        // calling round()
-        assertEval(Ignored.ImplementationError, "{ round(1.123456,digit=2.8) }");
+        assertEval("{ round(1.123456,digit=2.8) }");
+        assertEval("{ round(1.123456,digit=2.5) }");
+        assertEval("{ round(1.123456,digit=2.3) }");
+        assertEval("{ round(12344.126,digit=-2.8) }");
+        assertEval("{ round(12344.123456,digit=-2.5) }");
+        assertEval("{ round(12344.123456,digit=-2.3) }");
 
         assertEval("{ typeof(round(42L)); }");
         assertEval("{ typeof(round(TRUE)); }");
-- 
GitLab