Skip to content
Snippets Groups Projects
Commit d7aa001e authored by Stepan Sindelar's avatar Stepan Sindelar
Browse files

[GR-2798] Implement findInterval builtin.

PullRequest: fastr/1259
parents b14f9cc2 c07e2ccc
No related branches found
No related tags found
No related merge requests found
......@@ -494,6 +494,7 @@ public class BasePackage extends RBuiltinPackage {
add(FileFunctions.ListFiles.class, FileFunctionsFactory.ListFilesNodeGen::create);
add(FileFunctions.ListDirs.class, FileFunctionsFactory.ListDirsNodeGen::create);
add(FileFunctions.Unlink.class, FileFunctionsFactory.UnlinkNodeGen::create);
add(FindInterval.class, FindIntervalNodeGen::create);
add(ForceAndCall.class, ForceAndCallNodeGen::create);
add(Formals.class, FormalsNodeGen::create);
add(Format.class, FormatNodeGen::create);
......
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (c) 2002--2016, The R Core Team
* Copyright (c) 2017, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_ARITHMETIC;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
/**
* Note: R wrapper function {@code findInterval(x,vec,...)} has first two arguments swapped.
*/
@RBuiltin(name = "findInterval", kind = INTERNAL, parameterNames = {"xt", "x", "rightmost.closed", "all.inside", "left.open"}, behavior = PURE_ARITHMETIC)
public abstract class FindInterval extends RBuiltinNode.Arg5 {
static {
Casts casts = new Casts(FindInterval.class);
casts.arg("xt").mustBe(doubleValue(), Message.INVALID_INPUT).asDoubleVector();
casts.arg("x").mustBe(doubleValue(), Message.INVALID_INPUT).asDoubleVector();
casts.arg("rightmost.closed").asLogicalVector().findFirst().mustNotBeNA().map(toBoolean());
casts.arg("all.inside").asLogicalVector().findFirst().map(toBoolean());
casts.arg("left.open").asLogicalVector().findFirst().map(toBoolean());
}
@Specialization(guards = {"xtAccess.supports(xt)", "xAccess.supports(x)"})
RAbstractIntVector doFindInterval(RAbstractDoubleVector xt, RAbstractDoubleVector x, boolean right, boolean inside, boolean leftOpen,
@Cached("createEqualityProfile()") ValueProfile leftOpenProfile,
@Cached("create(xt)") VectorAccess xtAccess,
@Cached("create(xt)") VectorAccess xAccess,
@Cached("create()") VectorFactory vectorFactory) {
boolean leftOpenProfiled = leftOpenProfile.profile(leftOpen);
try (SequentialIterator xIter = xAccess.access(x)) {
int[] result = new int[xAccess.getLength(xIter)];
int i = 0;
boolean complete = true;
int previous = 1;
while (xAccess.next(xIter)) {
if (xAccess.isNA(xIter)) {
previous = RRuntime.INT_NA;
complete = false;
} else {
try (RandomIterator xtIter = xtAccess.randomAccess(xt)) {
previous = findInterval2(xtAccess, xtIter, xAccess.getDouble(xIter), right, inside, leftOpenProfiled, previous);
}
}
result[i++] = previous;
}
return vectorFactory.createIntVector(result, complete);
}
}
@Specialization(replaces = "doFindInterval")
RAbstractIntVector doFindIntervalGeneric(RAbstractDoubleVector xt, RAbstractDoubleVector x, boolean right, boolean inside, boolean leftOpen,
@Cached("createEqualityProfile()") ValueProfile leftOpenProfile,
@Cached("create()") VectorFactory factory) {
return doFindInterval(xt, x, right, inside, leftOpen, leftOpenProfile, xt.slowPathAccess(), x.slowPathAccess(), factory);
}
// transcribed from appl/interv.c
private int findInterval2(VectorAccess xtAccess, RandomIterator xtIter, double x, boolean right, boolean inside, boolean leftOpen, int iloIn) {
int n = xtAccess.getLength(xtIter);
if (n == 0) {
return 0;
}
// Note: GNUR code is written with 1-based indexing (by shifting the pointer), we subtract
// one in each vector access
int ilo = iloIn;
if (ilo <= 0) {
double xt0 = xtAccess.getDouble(xtIter, 0);
if (xsmlr(leftOpen, x, xt0)) {
return leftBoundary(right, inside, x, xt0);
}
ilo = 1;
}
int ihi = ilo + 1;
if (ihi >= n) {
double xtLast = xtAccess.getDouble(xtIter, n - 1);
if (xgrtr(leftOpen, x, xtLast)) {
return rightBoundary(right, inside, x, xtLast, n);
}
if (n <= 1) {
/* x < xt[1] */
return leftBoundary(right, inside, x, xtAccess.getDouble(xtIter, 0));
}
ilo = n - 1;
ihi = n;
}
if (xsmlr(leftOpen, x, xtAccess.getDouble(xtIter, ihi - 1))) {
if (xgrtr(leftOpen, x, xtAccess.getDouble(xtIter, ilo - 1))) {
/* `lucky': same interval as last time */
return ilo;
}
/* **** now x < xt[ilo] . decrease ilo to capture x */
int istep = 1;
boolean done = false;
while (true) {
ihi = ilo;
ilo = ihi - istep;
if (ilo <= 1) {
break;
}
double xtilo = xtAccess.getDouble(xtIter, ilo - 1);
if ((leftOpen && x > xtilo) || (!leftOpen && x >= xtilo)) {
done = true;
break;
}
istep *= 2;
}
if (!done) {
ilo = 1;
double xt0 = xtAccess.getDouble(xtIter, 0);
if (xsmlr(leftOpen, x, xt0)) {
return leftBoundary(right, inside, x, xt0);
}
}
} else {
/* **** now x >= xt[ihi] . increase ihi to capture x */
int istep = 1;
boolean done = false;
while (true) {
ilo = ihi;
ihi = ilo + istep;
if (ihi >= n) {
break;
}
double xtihi = xtAccess.getDouble(xtIter, ihi - 1);
if ((leftOpen && x <= xtihi) || (!leftOpen && x < xtihi)) {
done = true;
break;
}
istep *= 2;
}
if (!done) {
double xtLast = xtAccess.getDouble(xtIter, n - 1);
if (xgrtr(leftOpen, x, xtLast)) {
return rightBoundary(right, inside, x, xtLast, n);
}
ihi = n;
}
}
// L50 and L51 in the original GNUR source differ only by ">" vs ">=" depending on leftOpen
assert ilo <= ihi;
while (true) {
int middle = (ilo + ihi) / 2;
if (middle == ilo) {
return ilo;
}
/* note. it is assumed that middle = ilo in case ihi = ilo+1 . */
double xtMiddle = xtAccess.getDouble(xtIter, middle - 1);
if ((!leftOpen && x >= xtMiddle) || (leftOpen && x > xtMiddle)) {
ilo = middle;
} else {
ihi = middle;
}
}
}
private static int leftBoundary(boolean right, boolean inside, double x, double xt0) {
return ((inside || (right && x == xt0)) ? 1 : 0);
}
private static int rightBoundary(boolean right, boolean inside, double x, double xtLast, int n) {
return ((inside || (right && x == xtLast)) ? (n - 1) : n);
}
private static boolean xsmlr(boolean leftOpen, double x, double val) {
return x < val || (leftOpen && x <= val);
}
private static boolean xgrtr(boolean leftOpen, double x, double val) {
return x > val || (!leftOpen && x >= val);
}
}
......@@ -26275,21 +26275,37 @@ logical(0)
#argv <- list(character(0), character(0)); .Internal(file.rename(argv[[1]], argv[[2]]))
logical(0)
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval1#Ignored.Unimplemented#
#argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))
Error: 4 arguments passed to .Internal(findInterval) which requires 5
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval1#
#argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))
[1] 3 3 3 3 4 4 4 4 5 5 5 5 6
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval2#Ignored.Unimplemented#
#argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))
Error: 4 arguments passed to .Internal(findInterval) which requires 5
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval1#
#argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))
[1] 2 3 3 3 3 4 4 4 4 5 5 5 5
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval3#Ignored.Unimplemented#
#argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))
Error: 4 arguments passed to .Internal(findInterval) which requires 5
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval2#
#argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))
[1] NA
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval2#
#argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))
[1] NA
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval3#
#argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))
integer(0)
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval3#
#argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))
integer(0)
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval4#
#argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))
[1] 0 0 0 1 1 1 1 1 2 2 2 2 2 3 3 3 3
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval4#Ignored.Unimplemented#
#argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))
Error: 4 arguments passed to .Internal(findInterval) which requires 5
##com.oracle.truffle.r.test.builtins.TestBuiltin_findInterval.testfindInterval4#
#argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))
[1] 0 0 0 0 1 1 1 1 1 2 2 2 2 2 3 3 3
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_floor.testFloor#
#if (!any(R.version$engine == "FastR")) { 1+1i } else { { floor(1.1+1.9i); } }
......@@ -19,27 +19,25 @@ public class TestBuiltin_findInterval extends TestBase {
@Test
public void testfindInterval1() {
// FIXME RInternalError: not implemented: .Internal findInterval
assertEval(Ignored.Unimplemented,
"argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))");
assertEval("argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))");
assertEval("argv <- list(c(1, 2, 3, 4, 5, 6, 7, 8, 9), c(3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5, 5.25, 5.5, 5.75, 6), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))");
}
@Test
public void testfindInterval2() {
// FIXME RInternalError: not implemented: .Internal findInterval
assertEval(Ignored.Unimplemented, "argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))");
assertEval("argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))");
assertEval("argv <- list(NA_real_, NA_real_, FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))");
}
@Test
public void testfindInterval3() {
// FIXME RInternalError: not implemented: .Internal findInterval
assertEval(Ignored.Unimplemented, "argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))");
assertEval("argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))");
assertEval("argv <- list(numeric(0), numeric(0), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))");
}
@Test
public void testfindInterval4() {
// FIXME RInternalError: not implemented: .Internal findInterval
assertEval(Ignored.Unimplemented,
"argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]]))");
assertEval("argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], T))");
assertEval("argv <- list(c(5, 10, 15), c(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), FALSE, FALSE); .Internal(findInterval(argv[[1]], argv[[2]], argv[[3]], argv[[4]], F))");
}
}
......@@ -102,6 +102,7 @@ com.oracle.truffle.r.native/run/Rclasspath.sh,oracle_bash.copyright
com.oracle.truffle.r.native/run/Rscript_exec.sh,oracle_bash.copyright
com.oracle.truffle.r.native/run/Rscript.sh,oracle_bash.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Abbrev.java,gnu_r_gentleman_ihaka2.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/FindInterval.java,gnu_r.core.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/APerm.java,purdue.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BaseGammaFunctions.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BaseVariables.java,gnu_r.copyright
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment