From aab32f767d539792c741885fbfddec4b07008758 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Wed, 14 Dec 2016 16:29:30 +0100 Subject: [PATCH] adapt and extends tests for special calls --- .../truffle/r/nodes/test/SpecialCallTest.java | 166 ++++++++++++------ 1 file changed, 112 insertions(+), 54 deletions(-) diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java index b72e022b32..b92c005198 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java @@ -28,11 +28,6 @@ import org.junit.Test; import com.oracle.truffle.api.RootCallTarget; import com.oracle.truffle.api.source.Source; import com.oracle.truffle.r.engine.TruffleRLanguage; -import com.oracle.truffle.r.nodes.access.WriteVariableSyntaxNode; -import com.oracle.truffle.r.nodes.control.BlockNode; -import com.oracle.truffle.r.nodes.control.ReplacementDispatchNode; -import com.oracle.truffle.r.nodes.function.RCallNode; -import com.oracle.truffle.r.nodes.function.RCallSpecialNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.FastROptions; import com.oracle.truffle.r.runtime.RError; @@ -59,12 +54,23 @@ public class SpecialCallTest extends TestBase { @Override protected Void visit(RSyntaxCall element) { - if (element instanceof RCallSpecialNode) { - special++; - } else if (element instanceof RCallNode) { - normal++; - } else { - assert element instanceof ReplacementDispatchNode || element instanceof WriteVariableSyntaxNode || element instanceof BlockNode : "unexpected node while testing"; + switch (element.getClass().getSimpleName()) { + case "SpecialReplacementNode": + case "SpecialVoidReplacementNode": + case "RCallSpecialNode": + special++; + break; + case "RCallNodeGen": + case "GenericReplacementNode": + normal++; + break; + case "ReplacementDispatchNode": + case "WriteVariableSyntaxNode": + case "BlockNode": + // ignored + break; + default: + throw new AssertionError("unexpected class: " + element.getClass().getSimpleName()); } accept(element.getSyntaxLHS()); for (RSyntaxElement arg : element.getSyntaxArguments()) { @@ -158,81 +164,133 @@ public class SpecialCallTest extends TestBase { assertCallCounts("1 + 1", 1, 0, 1, 0); assertCallCounts("1 + 1 * 2 + 4", 3, 0, 3, 0); - assertCallCounts("{ a <- 1; b <- 2; a + b }", 1, 0, 1, 0); - assertCallCounts("{ a <- 1; b <- 2; c <- 3; a + b * 2 * c}", 3, 0, 3, 0); + assertCallCounts("{ a <- 1; b <- 2 }", "a + b", 1, 0, 1, 0); + assertCallCounts("{ a <- 1; b <- 2; c <- 3 }", "a + b * 2 * c", 3, 0, 3, 0); - assertCallCounts("{ a <- data.frame(a=1); b <- 2; c <- 3; a + b * 2 * c}", 3, 1, 2, 2); - assertCallCounts("{ a <- 1; b <- data.frame(a=1); c <- 3; a + b * 2 * c}", 3, 1, 0, 4); + assertCallCounts("{ a <- data.frame(a=1); b <- 2; c <- 3 }", "a + b * 2 * c", 3, 0, 2, 1); + assertCallCounts("{ a <- 1; b <- data.frame(a=1); c <- 3 }", "a + b * 2 * c", 3, 0, 0, 3); assertCallCounts("1 %*% 1", 0, 1, 0, 1); } @Test public void testSubset() { - assertCallCounts("{ a <- 1:10; a[1] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[2] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[4] }", 1, 1, 1, 1); - assertCallCounts("{ a <- list(c(1,2,3,4),2,3); a[1] }", 1, 2, 1, 2); - - assertCallCounts("{ a <- c(1,2,3,4); a[0.1] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[5] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[0] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[-1] }", 0, 3, 0, 3); // "-1" is a unary expression - assertCallCounts("{ a <- c(1,2,3,4); b <- -1; a[b] }", 1, 2, 0, 3); - assertCallCounts("{ a <- c(1,2,3,4); a[NA_integer_] }", 1, 1, 0, 2); + assertCallCounts("a <- 1:10", "a[1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[2]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[4]", 1, 0, 1, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0.1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[5]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0]", 1, 0, 1, 0); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[b]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_]", 1, 0, 1, 0); + + assertCallCounts("a <- c(1,2,3,4)", "a[-1]", 0, 2, 0, 2); // "-1" is a unary expression + assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F]", 0, 1, 0, 1); } @Test public void testSubscript() { - assertCallCounts("{ a <- 1:10; a[[1]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[[2]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[[4]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- list(c(1,2,3,4),2,3); a[[1]] }", 1, 2, 1, 2); - assertCallCounts("{ a <- list(a=c(1,2,3,4),2,3); a[[1]] }", 1, 2, 1, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[[0.1]] }", 1, 1, 1, 1); - - assertCallCounts("{ a <- c(1,2,3,4); a[[5]] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[[0]] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); b <- -1; a[[b]] }", 1, 2, 0, 3); - assertCallCounts("{ a <- c(1,2,3,4); a[[NA_integer_]] }", 1, 1, 0, 2); + assertCallCounts("a <- 1:10", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[2]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[4]]", 1, 0, 1, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[5]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0]]", 1, 0, 1, 0); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[[b]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[NA_integer_]]", 1, 0, 1, 0); + + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=T, 1]]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=F, 1]]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[1, drop=F]]", 0, 1, 0, 1); } - private static void assertCallCounts(String str, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { + @Test + public void testUpdateSubset() { + assertCallCounts("a <- 1:10", "a[1] <- 1", 1, 0, 1, 1); // sequence + assertCallCounts("a <- c(1,2,3,4)", "a[2] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[4] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[1] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0.1] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[5] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[0] <- 1", 1, 0, 1, 1); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[b] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_] <- 1", 1, 0, 1, 1); + + assertCallCounts("a <- c(1,2,3,4)", "a[-1] <- 1", 0, 2, 0, 3); // "-1" is a unary expression + assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F] <- 1", 0, 1, 0, 2); + } + + @Test + public void testUpdateSubscript() { + assertCallCounts("a <- 1:10", "a[[1]] <- 1", 1, 0, 1, 1); // sequence + assertCallCounts("a <- c(1,2,3,4)", "a[[2]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[4]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[5]] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[0]] <- 1", 1, 0, 1, 1); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[[b]] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[NA_integer_]] <- 1", 1, 0, 1, 1); + + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=T, 1]] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=F, 1]] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[[1, drop=F]] <- 1", 0, 1, 0, 2); + } + + private static void assertCallCounts(String test, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { + assertCallCounts("{}", test, initialSpecialCount, initialNormalCount, finalSpecialCount, finalNormalCount); + } + + private static void assertCallCounts(String setup, String test, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { if (!FastROptions.UseSpecials.getBooleanValue()) { return; } - Source source = Source.newBuilder(str).mimeType(TruffleRLanguage.MIME).name("test").build(); + Source setupSource = Source.newBuilder(setup).mimeType(TruffleRLanguage.MIME).name("test").build(); + Source testSource = Source.newBuilder(test).mimeType(TruffleRLanguage.MIME).name("test").build(); - RExpression expression = testVMContext.getThisEngine().parse(source); - assert expression.getLength() == 1; - RootCallTarget callTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) expression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); + RExpression setupExpression = testVMContext.getThisEngine().parse(setupSource); + RExpression testExpression = testVMContext.getThisEngine().parse(testSource); + assert setupExpression.getLength() == 1; + assert testExpression.getLength() == 1; + RootCallTarget setupCallTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) setupExpression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); + RootCallTarget testCallTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) testExpression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); try { - CountCallsVisitor count1 = new CountCallsVisitor(callTarget); - Assert.assertEquals("initial special call count '" + str + "': ", initialSpecialCount, count1.special); - Assert.assertEquals("initial normal call count '" + str + "': ", initialNormalCount, count1.normal); + CountCallsVisitor count1 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("initial special call count '" + setup + "; " + test + "': ", initialSpecialCount, count1.special); + Assert.assertEquals("initial normal call count '" + setup + "; " + test + "': ", initialNormalCount, count1.normal); try { - callTarget.call(REnvironment.globalEnv().getFrame()); + setupCallTarget.call(REnvironment.globalEnv().getFrame()); + testCallTarget.call(REnvironment.globalEnv().getFrame()); } catch (RError e) { // ignore } - CountCallsVisitor count2 = new CountCallsVisitor(callTarget); - Assert.assertEquals("special call count after first call '" + str + "': ", finalSpecialCount, count2.special); - Assert.assertEquals("normal call count after first call '" + str + "': ", finalNormalCount, count2.normal); + CountCallsVisitor count2 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("special call count after first call '" + setup + "; " + test + "': ", finalSpecialCount, count2.special); + Assert.assertEquals("normal call count after first call '" + setup + "; " + test + "': ", finalNormalCount, count2.normal); try { - callTarget.call(REnvironment.globalEnv().getFrame()); + setupCallTarget.call(REnvironment.globalEnv().getFrame()); + testCallTarget.call(REnvironment.globalEnv().getFrame()); } catch (RError e) { // ignore } - CountCallsVisitor count3 = new CountCallsVisitor(callTarget); - Assert.assertEquals("special call count after second call '" + str + "': ", finalSpecialCount, count3.special); - Assert.assertEquals("normal call count after second call '" + str + "': ", finalNormalCount, count3.normal); + CountCallsVisitor count3 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("special call count after second call '" + setup + "; " + test + "': ", finalSpecialCount, count3.special); + Assert.assertEquals("normal call count after second call '" + setup + "; " + test + "': ", finalNormalCount, count3.normal); } catch (AssertionError e) { - new PrintCallsVisitor().print(callTarget); + new PrintCallsVisitor().print(testCallTarget); throw e; } } -- GitLab