diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sqrt.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sqrt.java index 9322f2a98d20247ff678c5f08660d7038b20be1a..0ca7c92bcc9ab92f5694e997ff28d5214a6f65e5 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sqrt.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sqrt.java @@ -22,86 +22,58 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.runtime.RBuiltinKind.*; +import static com.oracle.truffle.r.runtime.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.dsl.*; -import com.oracle.truffle.api.profiles.ConditionProfile; -import com.oracle.truffle.r.nodes.builtin.*; -import com.oracle.truffle.r.runtime.*; -import com.oracle.truffle.r.runtime.data.*; -import com.oracle.truffle.r.runtime.ops.na.*; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; +import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNodeGen; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNode; +import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNodeGen; +import com.oracle.truffle.r.runtime.RBuiltin; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RComplex; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; +import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; @RBuiltin(name = "sqrt", kind = PRIMITIVE, parameterNames = {"x"}) public abstract class Sqrt extends RBuiltinNode { - private final NACheck na = NACheck.create(); - private final ConditionProfile naConditionProfile = ConditionProfile.createBinaryProfile(); - private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); - private final NullProfile dimensionsProfile = NullProfile.create(); + + public static final UnaryArithmeticFactory SQRT = SqrtArithmetic::new; + + @Child private BoxPrimitiveNode boxPrimitive = BoxPrimitiveNodeGen.create(); + @Child private UnaryArithmeticNode sqrtNode = UnaryArithmeticNodeGen.create(SQRT, RError.Message.NON_NUMERIC_MATH, RType.Double); @Specialization - public double sqrt(double x) { - controlVisibility(); - na.enable(x); - if (naConditionProfile.profile(na.check(x))) { - return RRuntime.DOUBLE_NA; - } else { - return Math.sqrt(x); - } + protected Object sqrt(Object value) { + return sqrtNode.execute(boxPrimitive.execute(value)); } - @Specialization - protected double sqrt(int x) { - controlVisibility(); - na.enable(x); - if (naConditionProfile.profile(na.check(x))) { - return RRuntime.DOUBLE_NA; - } else { - return Math.sqrt(x); + public static class SqrtArithmetic extends UnaryArithmetic { + + @Override + public int op(byte op) { + return op; } - } - @Specialization - protected double sqrt(byte x) { - controlVisibility(); - // sqrt for logical values: TRUE -> 1, FALSE -> 0, NA -> NA - na.enable(x); - if (naConditionProfile.profile(na.check(x))) { - return RRuntime.DOUBLE_NA; - } else { - return x; + @Override + public int op(int op) { + return (int) Math.sqrt(op); } - } - @Specialization - protected RDoubleVector sqrt(RIntSequence xs) { - controlVisibility(); - double[] res = new double[xs.getLength()]; - int current = xs.getStart(); - for (int i = 0; i < xs.getLength(); i++) { - double sqrt = Math.sqrt(current); - res[i] = sqrt; - current += xs.getStride(); + @Override + public double op(double op) { + return Math.sqrt(op); } - RDoubleVector result = RDataFactory.createDoubleVector(res, na.neverSeenNA(), dimensionsProfile.profile(xs.getDimensions()), xs.getNames(attrProfiles)); - result.copyRegAttributesFrom(xs); - return result; - } - @Specialization - protected RDoubleVector sqrt(RDoubleVector xs) { - controlVisibility(); - double[] res = new double[xs.getLength()]; - na.enable(xs); - for (int i = 0; i < xs.getLength(); i++) { - if (naConditionProfile.profile(na.check(xs.getDataAt(i)))) { - res[i] = RRuntime.DOUBLE_NA; - } else { - res[i] = Math.sqrt(xs.getDataAt(i)); - } + @Override + public RComplex op(double re, double im) { + double r = Math.sqrt(Math.sqrt(re * re + im * im)); + double theta = Math.atan2(im, re) / 2; + return RComplex.valueOf(r * Math.cos(theta), r * Math.sin(theta)); } - RDoubleVector result = RDataFactory.createDoubleVector(res, na.neverSeenNA(), dimensionsProfile.profile(xs.getDimensions()), xs.getNames(attrProfiles)); - result.copyRegAttributesFrom(xs); - return result; + } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sqrt.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sqrt.java index aa7dc5fb6ddd111639af4ba26c73ad0fa2bab110..8fcc967e2b6eafb625e27baf4756588585743c07 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sqrt.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sqrt.java @@ -29,7 +29,7 @@ public class TestBuiltin_sqrt extends TestBase { @Test public void testsqrt3() { - assertEval(Ignored.Unknown, "argv <- list(-17+0i);sqrt(argv[[1]]);"); + assertEval("argv <- list(-17+0i);sqrt(argv[[1]]);"); } @Test @@ -54,13 +54,12 @@ public class TestBuiltin_sqrt extends TestBase { @Test public void testsqrt8() { - assertEval(Ignored.Unknown, - "argv <- list(structure(c(660, 543, 711, 500, 410, 309, 546, 351, 269, 203, 370, 193, 181, 117, 243, 136, 117, 87, 154, 84), .Dim = 4:5, .Dimnames = list(c('Rural Male', 'Rural Female', 'Urban Male', 'Urban Female'), c('70-74', '65-69', '60-64', '55-59', '50-54'))));sqrt(argv[[1]]);"); + assertEval("argv <- list(structure(c(660, 543, 711, 500, 410, 309, 546, 351, 269, 203, 370, 193, 181, 117, 243, 136, 117, 87, 154, 84), .Dim = 4:5, .Dimnames = list(c('Rural Male', 'Rural Female', 'Urban Male', 'Urban Female'), c('70-74', '65-69', '60-64', '55-59', '50-54'))));sqrt(argv[[1]]);"); } @Test public void testsqrt9() { - assertEval(Ignored.Unknown, "argv <- list(c(6L, 5L, 4L, 3L, 2L, 1L, 0L, NA, NA, NA, NA));sqrt(argv[[1]]);"); + assertEval("argv <- list(c(6L, 5L, 4L, 3L, 2L, 1L, 0L, NA, NA, NA, NA));sqrt(argv[[1]]);"); } @Test @@ -70,7 +69,12 @@ public class TestBuiltin_sqrt extends TestBase { @Test public void testsqrt11() { - assertEval(Ignored.Unknown, "argv <- list(0+1i);sqrt(argv[[1]]);"); + assertEval("argv <- list(0+1i);sqrt(argv[[1]]);"); + } + + @Test + public void testsqrt12() { + assertEval("argv <- list(c(TRUE, FALSE));sqrt(argv[[1]]);"); } @Test