Skip to content
Snippets Groups Projects
Commit 94d66a1c authored by Lukas Stadler's avatar Lukas Stadler
Browse files

improve atan2 builtin

parent efa6893c
No related branches found
No related tags found
No related merge requests found
......@@ -33,8 +33,10 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message;
......@@ -413,6 +415,7 @@ public class TrigExpFunctions {
private final NACheck yNACheck = NACheck.create();
private final NACheck xNACheck = NACheck.create();
private final LoopConditionProfile profile = LoopConditionProfile.createCountingProfile();
static {
Casts casts = new Casts(Atan2.class);
......@@ -420,34 +423,19 @@ public class TrigExpFunctions {
casts.arg(1).mapIf(numericValue(), asDoubleVector());
}
private double doFunDouble(double y, double x) {
double result = x;
if (!yNACheck.check(y) && !xNACheck.check(x)) {
result = Math.atan2(y, x);
}
return result;
}
@FunctionalInterface
private interface IntDoubleFunction {
double apply(int i);
}
private RDoubleVector doFun(int length, IntDoubleFunction yFun, IntDoubleFunction xFun,
LoopConditionProfile profile) {
double[] resultVector = new double[length];
private double[] prepareArray(int length) {
reportWork(length);
profile.profileCounted(length);
for (int i = 0; profile.inject(i < length); i++) {
double y = yFun.apply(i);
double x = xFun.apply(i);
if (xNACheck.check(y) || yNACheck.check(x)) {
resultVector[i] = RRuntime.DOUBLE_NA;
} else {
resultVector[i] = Math.atan2(y, x);
}
}
return new double[length];
}
private RDoubleVector createResult(double[] resultVector) {
return RDataFactory.createDoubleVector(resultVector, xNACheck.neverSeenNA() && yNACheck.neverSeenNA());
}
......@@ -455,47 +443,76 @@ public class TrigExpFunctions {
protected double atan2(double y, double x) {
xNACheck.enable(x);
yNACheck.enable(y);
return doFunDouble(y, x);
if (yNACheck.check(y) || xNACheck.check(x)) {
return RRuntime.DOUBLE_NA;
} else {
return Math.atan2(y, x);
}
}
@Specialization
protected RDoubleVector atan2(double y, RAbstractDoubleVector x,
@Cached("createCountingProfile()") LoopConditionProfile profile) {
@Specialization(guards = "x.getLength() > 0")
protected RDoubleVector atan2(double y, RAbstractDoubleVector x) {
xNACheck.enable(x);
yNACheck.enable(y);
return doFun(x.getLength(), i -> y, i -> x.getDataAt(i), profile);
double[] array = prepareArray(x.getLength());
for (int i = 0; profile.inject(i < array.length); i++) {
double xValue = x.getDataAt(i);
if (xNACheck.check(y) || yNACheck.check(xValue)) {
array[i] = RRuntime.DOUBLE_NA;
} else {
array[i] = Math.atan2(y, xValue);
}
}
return createResult(array);
}
@Specialization
protected RDoubleVector atan2(RAbstractDoubleVector y, double x,
@Cached("createCountingProfile()") LoopConditionProfile profile) {
@Specialization(guards = "y.getLength() > 0")
protected RDoubleVector atan2(RAbstractDoubleVector y, double x) {
xNACheck.enable(x);
yNACheck.enable(y);
return doFun(y.getLength(), i -> y.getDataAt(i), i -> x, profile);
double[] array = prepareArray(y.getLength());
for (int i = 0; profile.inject(i < array.length); i++) {
double yValue = y.getDataAt(i);
if (xNACheck.check(yValue) || yNACheck.check(x)) {
array[i] = RRuntime.DOUBLE_NA;
} else {
array[i] = Math.atan2(yValue, x);
}
}
return createResult(array);
}
@Specialization
protected RDoubleVector atan2(RAbstractDoubleVector y, RAbstractDoubleVector x,
@Cached("createCountingProfile()") LoopConditionProfile profile) {
@Specialization(guards = {"y.getLength() > 0", "x.getLength() > 0"})
protected RDoubleVector atan2(RAbstractDoubleVector y, RAbstractDoubleVector x) {
int xLength = x.getLength();
int yLength = y.getLength();
xNACheck.enable(x);
yNACheck.enable(y);
return doFun(Math.max(yLength, xLength), i -> y.getDataAt(i % yLength), i -> x.getDataAt(i % xLength),
profile);
double[] array = prepareArray(Math.max(yLength, xLength));
for (int i = 0; profile.inject(i < array.length); i++) {
double yValue = y.getDataAt(i % yLength);
double xValue = x.getDataAt(i % xLength);
if (xNACheck.check(yValue) || yNACheck.check(xValue)) {
array[i] = RRuntime.DOUBLE_NA;
} else {
array[i] = Math.atan2(yValue, xValue);
}
}
return createResult(array);
}
@Specialization(guards = "y.getLength() == 0 || x.getLength() == 0")
protected RDoubleVector atan2Empty(@SuppressWarnings("unused") RAbstractDoubleVector y, @SuppressWarnings("unused") RAbstractDoubleVector x) {
return RDataFactory.createEmptyDoubleVector();
}
@Specialization(guards = {"!isDouble(x) || !isDouble(y)"})
@TruffleBoundary
@Fallback
protected Object atan2(Object x, Object y) {
CompilerDirectives.transferToInterpreter();
if (x instanceof RAbstractComplexVector || y instanceof RAbstractComplexVector) {
throw RInternalError.unimplemented("atan2 for complex values");
}
throw error(RError.Message.NON_NUMERIC_MATH);
}
protected static boolean isDouble(Object x) {
return x instanceof Double || x instanceof RAbstractDoubleVector;
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment