From 695aa6c5ed468983a8c010df010c1969d90d1a6a Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Thu, 22 Mar 2018 12:05:52 +0100 Subject: [PATCH] assign proper env to "deriv" results --- .../truffle/r/library/stats/deriv/Deriv.java | 40 ++++++++----------- .../truffle/r/test/ExpectedTestOutput.test | 14 ++++++- .../r/test/builtins/TestBuiltin_deriv.java | 7 ++++ 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/Deriv.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/Deriv.java index e4b7de9ed1..5bb5eee332 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/Deriv.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/Deriv.java @@ -112,12 +112,12 @@ public abstract class Deriv extends RExternalBuiltinNode { return DerivNodeGen.create(); } - public abstract Object execute(VirtualFrame frame, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5); + public abstract Object execute(Object arg1, Object arg2, Object arg3, Object arg4, Object arg5); @Override public Object call(VirtualFrame frame, RArgsValuesAndNames args) { checkLength(args, 5); - return execute(frame, castArg(args, 0), castArg(args, 1), castArg(args, 2), castArg(args, 3), castArg(args, 4)); + return execute(castArg(args, 0), castArg(args, 1), castArg(args, 2), castArg(args, 3), castArg(args, 4)); } @Override @@ -130,43 +130,34 @@ public abstract class Deriv extends RExternalBuiltinNode { } @Specialization(guards = "isConstant(expr)") - protected Object derive(VirtualFrame frame, Object expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { - return deriveInternal(frame.materialize(), createConstant(expr), names, functionArg, tag, hessian); - } - - @TruffleBoundary - private static RSyntaxNode createConstant(Object expr) { - return RContext.getASTBuilder().constant(RSyntaxNode.LAZY_DEPARSE, expr); + protected Object derive(Object expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { + return deriveInternal(RSyntaxConstant.createDummyConstant(RSyntaxNode.INTERNAL, expr), names, functionArg, tag, hessian); } @Specialization - protected Object derive(VirtualFrame frame, RSymbol expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { - return deriveInternal(frame.materialize(), createLookup(expr), names, functionArg, tag, hessian); - } - - @TruffleBoundary - private static RSyntaxElement createLookup(RSymbol expr) { - return RContext.getASTBuilder().lookup(RSyntaxNode.LAZY_DEPARSE, expr.getName(), false); + protected Object derive(RSymbol expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { + return deriveInternal(RSyntaxLookup.createDummyLookup(RSyntaxNode.INTERNAL, expr.getName(), false), names, functionArg, tag, hessian); } @Specialization - protected Object derive(VirtualFrame frame, RExpression expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian, + protected Object derive(RExpression expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian, @Cached("create()") Deriv derivNode) { - return derivNode.execute(frame, expr.getDataAt(0), names, functionArg, tag, hessian); + return derivNode.execute(expr.getDataAt(0), names, functionArg, tag, hessian); } @Specialization(guards = "expr.isLanguage()") - protected Object derive(VirtualFrame frame, RPairList expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { - return deriveInternal(frame.materialize(), extracted(expr), names, functionArg, tag, hessian); + protected Object derive(RPairList expr, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { + return deriveInternal(getSyntaxElement(expr), names, functionArg, tag, hessian); } - private RSyntaxElement extracted(RPairList expr) { + @TruffleBoundary + private static RSyntaxElement getSyntaxElement(RPairList expr) { return expr.getSyntaxElement(); } @TruffleBoundary - private Object deriveInternal(MaterializedFrame frame, RSyntaxElement elem, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { - return findDerive(elem, names, functionArg, tag, hessian).getResult(frame.materialize(), getRLanguage()); + private Object deriveInternal(RSyntaxElement elem, RAbstractStringVector names, Object functionArg, String tag, boolean hessian) { + return findDerive(elem, names, functionArg, tag, hessian).getResult(getRLanguage(), functionArg); } private static final class DerivResult { @@ -186,10 +177,11 @@ public abstract class Deriv extends RExternalBuiltinNode { result = null; } - private Object getResult(MaterializedFrame frame, TruffleRLanguage language) { + private Object getResult(TruffleRLanguage language, Object functionArg) { if (result != null) { return result; } + MaterializedFrame frame = functionArg instanceof RFunction ? ((RFunction) functionArg).getEnclosingFrame() : RContext.getInstance().stateREnvironment.getGlobalFrame(); RootCallTarget callTarget = RContext.getASTBuilder().rootFunction(language, RSyntaxNode.LAZY_DEPARSE, targetArgs, blockCall, null); FrameSlotChangeMonitor.initializeEnclosingFrame(callTarget.getRootNode().getFrameDescriptor(), frame); return RDataFactory.createFunction(RFunction.NO_NAME, RFunction.NO_NAME, callTarget, null, frame); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index e43dafd831..dc4af52603 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -21412,6 +21412,18 @@ attr(,"gradient") x [1,] NaN +##com.oracle.truffle.r.test.builtins.TestBuiltin_deriv.testEnvironment# +#environment(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = TRUE)) +<environment: R_GlobalEnv> + +##com.oracle.truffle.r.test.builtins.TestBuiltin_deriv.testEnvironment# +#environment(local(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = TRUE))) +<environment: R_GlobalEnv> + +##com.oracle.truffle.r.test.builtins.TestBuiltin_deriv.testEnvironment# +#environment(local(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = deriv))) +<environment: namespace:stats> + ##com.oracle.truffle.r.test.builtins.TestBuiltin_deriv.testFunctionGenereration#Output.IgnoreWhitespace# #(df <- deriv(~x^2*sin(x), "x", function.arg=TRUE));df(0) function (x) @@ -28269,7 +28281,7 @@ Error: atomic vector arguments only function (x = 4) x + 1 -##com.oracle.truffle.r.test.builtins.TestBuiltin_function.testFunctionFunction# +##com.oracle.truffle.r.test.builtins.TestBuiltin_function.testFunctionFunction#Output.MayIgnoreErrorContext# #eval(call('function', 1, expression(x + 1)[[1]])) Error in eval(call("function", 1, expression(x + 1)[[1]])) : invalid formal argument list for "function" diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_deriv.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_deriv.java index 190994eb99..d09509aa25 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_deriv.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_deriv.java @@ -190,6 +190,13 @@ public class TestBuiltin_deriv extends TestBase { assertEval(Output.IgnoreWhitespace, "(df <- deriv(~x^2*sin(x), \"x\", function.arg=function(x)NULL));df(0)"); } + @Test + public void testEnvironment() { + assertEval("environment(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = TRUE))"); + assertEval("environment(local(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = TRUE)))"); + assertEval("environment(local(deriv((y ~ sin(cos(x) * y)), c('x','y'), func = deriv)))"); + } + @Test public void testUnusualExprs() { assertEval("(df <- deriv(expression(x^2*sin(x)), \"x\"));df(0)"); -- GitLab