Skip to content
Snippets Groups Projects
Commit 97949433 authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

Deriv builtin implementation

parent 0e7124b4
No related branches found
No related tags found
No related merge requests found
Showing
with 4050 additions and 16 deletions
/*
* Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.library.stats.deriv;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.size;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.typeName;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.RASTUtils;
import com.oracle.truffle.r.nodes.access.ConstantNode;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RExpression;
import com.oracle.truffle.r.runtime.data.RLanguage;
import com.oracle.truffle.r.runtime.data.RSymbol;
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxConstant;
import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxVisitor;
public abstract class D extends RExternalBuiltinNode.Arg2 {
static {
Casts casts = new Casts(D.class);
casts.arg(0, "expr").mustBe(instanceOf(RExpression.class).or(instanceOf(RLanguage.class)).or(instanceOf(RSymbol.class)).or(numericValue()).or(complexValue()),
RError.Message.INVALID_EXPRESSION_TYPE, typeName());
casts.arg(1, "namevec").mustBe(stringValue()).asStringVector().mustBe(notEmpty(), RError.Message.GENERIC, "variable must be a character string").shouldBe(size(1),
RError.Message.ONLY_FIRST_VARIABLE_NAME).findFirst();
}
public static D create() {
return DNodeGen.create();
}
protected static boolean isConstant(Object expr) {
return !(expr instanceof RLanguage || expr instanceof RExpression || expr instanceof RSymbol);
}
@Specialization(guards = "isConstant(expr)")
@TruffleBoundary
protected Object doD(Object expr, String var) {
return doD(ConstantNode.create(expr), var);
}
@Specialization
@TruffleBoundary
protected Object doD(RSymbol expr, String var) {
return doD(RContext.getASTBuilder().lookup(RSyntaxNode.LAZY_DEPARSE, expr.getName(), false), var);
}
@Specialization
@TruffleBoundary
protected Object doD(RLanguage expr, String var) {
return doD((RSyntaxElement) expr.getRep(), var);
}
@Specialization
@TruffleBoundary
protected Object doD(RExpression expr, String var,
@Cached("create()") D dNode) {
return dNode.execute(expr.getDataAt(0), var);
}
private static Object doD(RSyntaxElement elem, String var) {
RSyntaxVisitor<RSyntaxElement> vis = new DerivVisitor(var);
RSyntaxElement dExpr = vis.accept(elem);
dExpr = Deriv.addParens(dExpr);
return RASTUtils.createLanguageElement(dExpr);
}
}
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (c) 1995, 1996 Robert Gentleman and Ross Ihaka
* Copyright (c) 1997-2013, The R Core Team
* Copyright (c) 2015, 2017, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.library.stats.deriv;
import static com.oracle.truffle.r.library.stats.deriv.Deriv.*;
import com.oracle.truffle.r.nodes.access.ConstantNode;
import com.oracle.truffle.r.runtime.RDeparse;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.Utils;
import com.oracle.truffle.r.runtime.nodes.RSyntaxCall;
import com.oracle.truffle.r.runtime.nodes.RSyntaxConstant;
import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
import com.oracle.truffle.r.runtime.nodes.RSyntaxFunction;
import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
import com.oracle.truffle.r.runtime.nodes.RSyntaxVisitor;
//Transcribed from GnuR, library/stats/src/deriv.c
public class DerivVisitor extends RSyntaxVisitor<RSyntaxElement> {
private final String var;
DerivVisitor(String var) {
this.var = var;
}
@Override
protected RSyntaxElement visit(RSyntaxCall call) {
String functionName = getFunctionName(call);
assert Utils.isInterned(functionName);
RSyntaxElement arg0 = call.getSyntaxArguments()[0];
RSyntaxElement arg1 = call.getSyntaxArguments().length > 1 ? call.getSyntaxArguments()[1] : null;
if (functionName == LEFT_PAREN) {
return accept(arg0);
}
if (functionName == PLUS) {
if (call.getSyntaxArguments().length == 1) {
return accept(arg0);
} else {
return simplify(PLUS, accept(arg0), accept(arg1));
}
}
if (functionName == MINUS) {
if (call.getSyntaxArguments().length == 1) {
return simplify(MINUS, accept(arg0), null);
} else {
return simplify(MINUS, accept(arg0), accept(arg1));
}
}
if (functionName == TIMES) {
return simplify(PLUS, simplify(TIMES, accept(arg0), cloneElement(arg1)),
simplify(TIMES, cloneElement(arg0), accept(arg1)));
}
if (functionName == DIVIDE) {
return simplify(MINUS,
simplify(DIVIDE, accept(arg0), cloneElement(arg1)),
simplify(DIVIDE,
simplify(TIMES, cloneElement(arg0), accept(arg1)),
simplify(POWER, cloneElement(arg1), ConstantNode.create(2.))));
}
if (functionName == POWER) {
if (isNumeric(arg1)) {
return simplify(TIMES,
arg1,
simplify(TIMES,
accept(arg0),
simplify(POWER, cloneElement(arg0), decDouble(arg1))));
} else {
// (a^b)' = a^(b-1).b.a' + a^b.log(a).b'
RSyntaxElement expr1 = simplify(TIMES,
simplify(POWER,
arg0,
simplify(MINUS, cloneElement(arg1), ConstantNode.create(1.))),
simplify(TIMES, cloneElement(arg1), accept(arg0)));
RSyntaxElement expr2 = simplify(TIMES,
simplify(POWER, cloneElement(arg0), cloneElement(arg1)),
simplify(TIMES,
simplify(LOG, cloneElement(arg0), null),
accept(arg1)));
return simplify(PLUS, expr1, expr2);
}
}
if (functionName == EXP) {
return simplify(TIMES, cloneElement(call), accept(arg0));
}
if (functionName == LOG) {
if (call.getSyntaxArguments().length != 1) {
throw RError.error(RError.SHOW_CALLER, RError.Message.GENERIC, "only single-argument calls are supported");
}
return simplify(DIVIDE, accept(arg0), cloneElement(arg0));
}
if (functionName == COS) {
return simplify(TIMES,
simplify(SIN, cloneElement(arg0), null),
simplify(MINUS, accept(arg0), null));
}
if (functionName == SIN) {
return simplify(TIMES,
simplify(COS, cloneElement(arg0), null),
accept(arg0));
}
if (functionName == TAN) {
return simplify(DIVIDE,
accept(arg0),
simplify(POWER,
simplify(COS, cloneElement(arg0), null),
ConstantNode.create(2.)));
}
if (functionName == COSH) {
return simplify(TIMES,
simplify(SINH, cloneElement(arg0), null),
accept(arg0));
}
if (functionName == SINH) {
return simplify(TIMES,
simplify(COSH, cloneElement(arg0), null),
accept(arg0));
}
if (functionName == TANH) {
return simplify(DIVIDE,
accept(arg0),
simplify(POWER,
simplify(COSH, cloneElement(arg0), null),
ConstantNode.create(2.)));
}
if (functionName == SQRT) {
return accept(simplify(POWER, cloneElement(arg0), ConstantNode.create(0.5)));
}
if (functionName == PNORM) {
return simplify(TIMES,
simplify(DNORM, cloneElement(arg0), null),
accept(arg0));
}
if (functionName == DNORM) {
return simplify(TIMES,
simplify(MINUS, cloneElement(arg0), null),
simplify(TIMES,
simplify(DNORM, cloneElement(arg0), null),
accept(arg0)));
}
if (functionName == ASIN) {
return simplify(DIVIDE,
accept(arg0),
simplify(SQRT,
simplify(MINUS,
ConstantNode.create(1.),
simplify(POWER, cloneElement(arg0), ConstantNode.create(2.))),
null));
}
if (functionName == ACOS) {
return simplify(MINUS,
simplify(DIVIDE,
accept(arg0),
simplify(SQRT,
simplify(MINUS,
ConstantNode.create(1.),
simplify(POWER, cloneElement(arg0), ConstantNode.create(2.))),
null)),
null);
}
if (functionName == ATAN) {
return simplify(DIVIDE,
accept(arg0),
simplify(PLUS,
ConstantNode.create(1.),
simplify(POWER, cloneElement(arg0), ConstantNode.create(2.))));
}
if (functionName == LGAMMA) {
return simplify(TIMES,
accept(arg0),
simplify(DIGAMMA, cloneElement(arg0), null));
}
if (functionName == GAMMA) {
return simplify(TIMES,
accept(arg0),
simplify(TIMES, cloneElement(call),
simplify(DIGAMMA, cloneElement(arg0), null)));
}
if (functionName == DIGAMMA) {
return simplify(TIMES,
accept(arg0),
simplify(TRIGAMMA, cloneElement(arg0), null));
}
if (functionName == TRIGAMMA) {
return simplify(TIMES,
accept(arg0),
simplify(PSIGAMMA, cloneElement(arg0), ConstantNode.create(2)));
}
if (functionName == PSIGAMMA) {
if (call.getSyntaxArguments().length == 1) {
return simplify(TIMES,
accept(arg0),
simplify(PSIGAMMA, cloneElement(arg0), ConstantNode.create(1)));
} else if (isIntegerOrDouble(arg1)) {
return simplify(TIMES,
accept(arg0),
simplify(PSIGAMMA, cloneElement(arg0), incInteger(arg1)));
} else {
return simplify(TIMES,
accept(arg0),
simplify(PSIGAMMA,
cloneElement(arg0),
simplify(PLUS, cloneElement(arg1), ConstantNode.create(1))));
}
}
throw RError.error(RError.SHOW_CALLER, RError.Message.NOT_IN_DERIVATIVE_TABLE, RDeparse.deparseSyntaxElement(call.getSyntaxLHS()));
}
@Override
protected RSyntaxElement visit(RSyntaxConstant element) {
return ConstantNode.create(0.);
}
@Override
protected RSyntaxElement visit(RSyntaxLookup element) {
double dVal = element.getIdentifier().equals(var) ? 1 : 0;
return ConstantNode.create(dVal);
}
@Override
protected RSyntaxElement visit(RSyntaxFunction element) {
throw RInternalError.shouldNotReachHere();
}
private RSyntaxElement simplify(String functionName, RSyntaxElement arg1, RSyntaxElement arg2) {
if (functionName == PLUS) {
if (arg2 == null) {
return arg1;
} else if (isZero(arg1)) {
return arg2;
} else if (isZero(arg2)) {
return arg1;
} else if (isUminus(arg1)) {
return simplify(MINUS, arg2, arg(arg1, 0));
} else if (isUminus(arg2)) {
return simplify(MINUS, arg1, arg(arg2, 0));
} else {
return newCall(PLUS, arg1, arg2);
}
} else if (functionName == MINUS) {
if (arg2 == null) {
if (isZero(arg1)) {
return ConstantNode.create(0.);
} else if (isUminus(arg1)) {
return arg(arg1, 0);
} else {
return newCall(MINUS, arg1, arg2);
}
} else {
if (isZero(arg2)) {
return arg1;
} else if (isZero(arg1)) {
return simplify(MINUS, arg2, null);
} else if (isUminus(arg1)) {
return simplify(MINUS,
simplify(PLUS, arg(arg1, 0), arg2),
null);
} else if (isUminus(arg2)) {
return simplify(PLUS, arg1, arg(arg2, 0));
} else {
return newCall(MINUS, arg1, arg2);
}
}
} else if (functionName == TIMES) {
if (isZero(arg1) || isZero(arg2)) {
return ConstantNode.create(0.);
} else if (isOne(arg1)) {
return arg2;
} else if (isOne(arg2)) {
return arg1;
} else if (isUminus(arg1)) {
return simplify(MINUS, simplify(TIMES, arg(arg1, 0), arg2), null);
} else if (isUminus(arg2)) {
return simplify(MINUS, simplify(TIMES, arg1, arg(arg2, 0)), null);
} else {
return newCall(TIMES, arg1, arg2);
}
} else if (functionName == DIVIDE) {
if (isZero(arg1)) {
return ConstantNode.create(0.);
} else if (isZero(arg2)) {
return ConstantNode.create(RRuntime.DOUBLE_NA);
} else if (isOne(arg2)) {
return arg1;
} else if (isUminus(arg1)) {
return simplify(MINUS, simplify(DIVIDE, arg(arg1, 0), arg2), null);
} else if (isUminus(arg2)) {
return simplify(MINUS, simplify(DIVIDE, arg1, arg(arg2, 0)), null);
} else {
return newCall(DIVIDE, arg1, arg2);
}
} else if (functionName == POWER) {
if (isZero(arg2)) {
return ConstantNode.create(1.);
} else if (isZero(arg1)) {
return ConstantNode.create(0.);
} else if (isOne(arg1)) {
return ConstantNode.create(1.);
} else if (isOne(arg2)) {
return arg1;
} else {
return newCall(POWER, arg1, arg2);
}
} else if (functionName == EXP) {
// FIXME: simplify exp(lgamma( E )) = gamma( E )
return newCall(EXP, arg1, null);
} else if (functionName == LOG) {
// FIXME: simplify log(gamma( E )) = lgamma( E )
return newCall(LOG, arg1, null);
} else if (functionName == COS || functionName == SIN || functionName == TAN || functionName == COSH || functionName == SINH || functionName == TANH || functionName == SQRT ||
functionName == PNORM || functionName == DNORM || functionName == ASIN || functionName == ACOS || functionName == ATAN || functionName == GAMMA || functionName == LGAMMA ||
functionName == DIGAMMA || functionName == TRIGAMMA || functionName == PSIGAMMA) {
return newCall(functionName, arg1, arg2);
} else {
return ConstantNode.create(RRuntime.DOUBLE_NA);
}
}
private static boolean isIntegerOrDouble(RSyntaxElement elem) {
if (elem instanceof RSyntaxConstant) {
Object val = ((RSyntaxConstant) elem).getValue();
return val instanceof Integer || val instanceof Double;
} else {
return false;
}
}
private static boolean isUminus(RSyntaxElement elem) {
if (elem instanceof RSyntaxCall && MINUS == getFunctionName(elem)) {
RSyntaxElement[] args = ((RSyntaxCall) elem).getSyntaxArguments();
switch (args.length) {
case 1:
return true;
case 2:
return false;
default:
throw RError.error(RError.SHOW_CALLER, RError.Message.GENERIC, "invalid form in unary minus check");
}
} else {
return false;
}
}
private static boolean isNumeric(RSyntaxElement elem) {
if (elem instanceof RSyntaxConstant) {
Object val = ((RSyntaxConstant) elem).getValue();
return val instanceof Integer || val instanceof Double || val instanceof Byte;
} else {
return false;
}
}
private static RSyntaxConstant decDouble(RSyntaxElement elem) {
assert elem instanceof RSyntaxConstant;
assert ((RSyntaxConstant) elem).getValue() instanceof Number;
Number n = (Number) ((RSyntaxConstant) elem).getValue();
return ConstantNode.create(n.doubleValue() - 1);
}
private static RSyntaxConstant incInteger(RSyntaxElement elem) {
assert elem instanceof RSyntaxConstant;
assert ((RSyntaxConstant) elem).getValue() instanceof Number;
Number n = (Number) ((RSyntaxConstant) elem).getValue();
return ConstantNode.create(n.intValue() + 1);
}
}
......@@ -55,6 +55,8 @@ import com.oracle.truffle.r.library.stats.RandFunctionsNodes.RandFunction3Node;
import com.oracle.truffle.r.library.stats.SignrankFreeNode;
import com.oracle.truffle.r.library.stats.SplineFunctionsFactory.SplineCoefNodeGen;
import com.oracle.truffle.r.library.stats.SplineFunctionsFactory.SplineEvalNodeGen;
import com.oracle.truffle.r.library.stats.deriv.D;
import com.oracle.truffle.r.library.stats.deriv.Deriv;
import com.oracle.truffle.r.library.stats.StatsFunctionsNodes;
import com.oracle.truffle.r.library.stats.WilcoxFreeNode;
import com.oracle.truffle.r.library.tools.C_ParseRdNodeGen;
......@@ -744,6 +746,11 @@ public class CallAndExternalFunctions {
switch (name) {
case "compcases":
return new CompleteCases();
// stats
case "doD":
return D.create();
case "deriv":
return Deriv.create();
// utils
case "countfields":
return CountFieldsNodeGen.create();
......
......@@ -101,6 +101,7 @@ public class ExtBuiltinsList {
com.oracle.truffle.r.library.methods.MethodsListDispatchFactory.R_getGenericNodeGen.class,
com.oracle.truffle.r.library.methods.MethodsListDispatchFactory.R_nextMethodCallNodeGen.class,
com.oracle.truffle.r.library.methods.MethodsListDispatchFactory.R_externalPtrPrototypeObjectNodeGen.class,
com.oracle.truffle.r.library.stats.deriv.DerivNodeGen.class,
};
@SuppressWarnings("unchecked")
......
......@@ -22,6 +22,12 @@
*/
package com.oracle.truffle.r.nodes.test;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asLogicalVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.chain;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.map;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
......@@ -30,6 +36,7 @@ import org.junit.Test;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.AndFilter;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.ScalarValue;
import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapByteToBoolean;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.DoubleFilter;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.MissingFilter;
import com.oracle.truffle.r.nodes.builtin.casts.Filter.NotFilter;
......@@ -42,6 +49,7 @@ import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.CoercionStep;
import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FilterStep;
import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FindFirstStep;
import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep;
import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep;
import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardedValuesAnalyser;
import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingAnalysisResult;
import com.oracle.truffle.r.runtime.RType;
......@@ -406,4 +414,26 @@ public class ForwardedValuesAnalyserTest {
assertFalse(result.isStringForwarded());
assertFalse(result.isMissingForwarded());
}
@Test
public void testReturnIfWithTrueBranchChain() {
//@formatter:off
PipelineStep<?, ?> findFirstBoolean = new CoercionStep<>(RType.Logical, false).setNext(new FindFirstStep<>(null, Byte.class, null)).setNext(new MapStep<>(new MapByteToBoolean(false)));
PipelineStep<?, ?> firstStep = new MapIfStep<>(new RTypeFilter<>(RType.Logical), // the condition
findFirstBoolean, null, true);
//@formatter:on
ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser();
ForwardingAnalysisResult result = fwdAn.analyse(firstStep);
// TODO: change it to the positive assertion when the selected mappers (such as
// MapByteToBoolean) are supported
assertFalse(result.isLogicalForwarded());
assertTrue(result.logicalForwarded.mapper instanceof MapByteToBoolean);
assertTrue(result.isDoubleForwarded());
assertTrue(result.isIntegerForwarded());
assertTrue(result.isNullForwarded());
assertTrue(result.isStringForwarded());
assertTrue(result.isMissingForwarded());
}
}
......@@ -431,13 +431,14 @@ public class RBuiltinDiagnostics {
print(1, "\nUnhandled argument combinations: " + nonCoveredArgsSet.size());
print(1, "");
printDeadSpecs();
if (diagSuite.diagConfig.verbose) {
for (List<Type> uncoveredArgs : nonCoveredArgsSet) {
print(1, uncoveredArgs.stream().map(t -> typeName(t)).collect(Collectors.toList()));
}
}
print(1, "");
printDeadSpecs();
}
private void printBuiltinHeader(int level) {
......
......@@ -426,6 +426,10 @@ public final class CastBuilder {
return new NotNAStep<>(null, null);
}
public static <T> PipelineStep<T, T> boxPrimitive() {
return new PipelineStep.BoxPrimitiveStep<>();
}
public static NullFilter nullValue() {
return NullFilter.INSTANCE;
}
......
......@@ -39,7 +39,7 @@ public abstract class ForwardingStatus {
};
public static final ForwardingStatus FORWARDED = new Forwarded(null);
final Mapper<?, ?> mapper;
public final Mapper<?, ?> mapper;
private final byte flag;
protected ForwardingStatus(byte flag, Mapper<?, ?> mapper) {
......@@ -97,7 +97,7 @@ public abstract class ForwardingStatus {
}
ForwardingStatus or(ForwardingStatus other) {
return fromFlag(or(this.flag, other.flag));
return fromFlag(or(this.flag, other.flag), this.mapper != null ? this.mapper : other.mapper);
}
ForwardingStatus not() {
......
......@@ -847,7 +847,12 @@ public final class RError extends RuntimeException {
TRUNCATE_NOT_ENABLED("truncation not enabled for this connection"),
TRUNCATE_UNSUPPORTED_FOR_CONN("cannot truncate connection: %s"),
INCOMPLETE_STRING_AT_EOF_DISCARDED("incomplete string at end of file has been discarded"),
INVALID_CHANNEL_OBJECT("invalid channel object type: %s");
INVALID_CHANNEL_OBJECT("invalid channel object type: %s"),
INVALID_TAG("invalid tag"),
INVALID_VARIABLE_NAMES("invalid variable names"),
INVALID_EXPRESSION("invalid expression in '%s'"),
INVALID_EXPRESSION_TYPE("expression must not be type '%s'"),
NOT_IN_DERIVATIVE_TABLE("Function '%s' is not in the derivatives table");
public final String message;
final boolean hasArgs;
......
/*
* Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
......@@ -197,4 +197,11 @@ public interface RCodeBuilder<T> {
default T call(SourceSection source, T lhs, T argument1, T argument2, T argument3) {
return call(source, lhs, Arrays.asList(argument(argument1), argument(argument2), argument(argument3)));
}
/**
* Helper function: create a call with four unnamed arguments.
*/
default T call(SourceSection source, T lhs, T argument1, T argument2, T argument3, T argument4) {
return call(source, lhs, Arrays.asList(argument(argument1), argument(argument2), argument(argument3), argument(argument4)));
}
}
/*
* Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.test.builtins;
import org.junit.Test;
import com.oracle.truffle.r.test.TestBase;
public class TestBuiltin_D extends TestBase {
@Test
public void testD() {
assertEval("(df <- D(expression(x^2*sin(x)), \"x\"));df(0)");
assertEval("(df <- D(quote(x^2*sin(x)), \"x\"));df(0)");
assertEval("g<-quote(x^2);(df <- D(g, \"x\"));df(0)");
assertEval("(df <- D(1, \"x\"));df(0)");
assertEval("x<-1;(df <- D(x, \"x\"));df(0)");
}
}
/*
* Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.test.builtins;
import org.junit.Test;
import com.oracle.truffle.r.test.TestBase;
import com.oracle.truffle.r.test.TestTrait;
public class TestBuiltin_deriv extends TestBase {
final class DerivExpr {
final String expr;
final String assertedExpr;
final int dn;
final boolean hessian;
DerivExpr(String expr, int dn) {
this(expr, dn, false);
}
DerivExpr(String expr, int dn, boolean hessian) {
this.expr = expr;
this.dn = dn;
this.hessian = hessian;
String vars = dn == 1 ? "c(\"x\")" : "c(\"x\",\"y\")";
String h = hessian ? "TRUE" : "FALSE";
this.assertedExpr = "deriv(~ " + expr + ", " + vars + ", hessian=" + h + ")";
}
DerivEval derive() {
assertEval(assertedExpr);
return new DerivEval(this);
}
DerivEval derive(TestTrait trait) {
assertEval(trait, assertedExpr);
return new DerivEval(this);
}
}
final class DerivEval {
final DerivExpr de;
final String assertedExpr;
DerivEval(DerivExpr de) {
this.de = de;
String vars = de.dn == 1 ? "x<-%s" : "x<-%s; y<-%s";
this.assertedExpr = "df <- " + de.assertedExpr + "; " + vars + "; eval(df)";
}
DerivEval eval(Object... vals) {
String ex = String.format(this.assertedExpr, vals);
assertEval(ex);
return this;
}
DerivEval eval(TestTrait trait, Object... vals) {
String ex = String.format(this.assertedExpr, vals);
assertEval(trait, ex);
return this;
}
DerivExpr withHessian() {
return new DerivExpr(de.expr, de.dn, true);
}
}
private DerivExpr deriv1(String expr) {
return new DerivExpr(expr, 1);
}
private DerivExpr deriv2(String expr) {
return new DerivExpr(expr, 2);
}
private void assertDerivAndEval1(String expr) {
deriv1(expr).derive().eval(0).withHessian().derive().eval(0);
}
private void assertDerivAndEval1(TestTrait trait, String expr) {
deriv1(expr).derive(trait).eval(0, 0);
}
private DerivEval assertDeriv1(String expr) {
return deriv1(expr).derive();
}
private void assertDerivAndEval2(String expr) {
deriv2(expr).derive().eval(0, 0).withHessian().derive().eval(0, 0);
}
@Test
public void testDeriveBasicExpressions1() {
assertDerivAndEval1("1");
assertDerivAndEval1("x");
assertDerivAndEval1("x+1");
assertDerivAndEval1("2*x");
assertDerivAndEval1("x/2");
assertDerivAndEval1(Ignored.OutputFormatting, "2/x");
assertDerivAndEval1("x^2");
assertDerivAndEval1("(x+1)+(x+2)");
assertDerivAndEval1("(x+1)-(x+2)");
assertDerivAndEval1("-(x+1)+(x+2)");
assertDerivAndEval1("-(x+1)-(x+2)");
assertDerivAndEval1("(x+1)*(x+2)");
deriv1("(x+1)/(x+2)").derive().eval(0).withHessian().derive(Output.IgnoreWhitespace).eval(0);
assertDerivAndEval1("(x+1)*(x+2*(x-1))");
assertDerivAndEval1(Ignored.OutputFormatting, "(x+1)^(x+2)");
}
@Test
public void testDeriveFunctions1() {
deriv1("log(x)").derive().eval(0).withHessian().derive(Ignored.OutputFormatting).eval(0).eval(1).eval(Ignored.MissingWarning,
-1);
assertDerivAndEval1("exp(x)");
assertDerivAndEval1("cos(x)");
assertDerivAndEval1("sin(x)");
assertDerivAndEval1("tan(x)");
assertDerivAndEval1("cosh(x)");
assertDerivAndEval1("sinh(x)");
deriv1("tanh(x)").derive().eval(0).withHessian().derive(Ignored.OutputFormatting).eval(0).eval(1).eval(-1);
assertDerivAndEval1("sqrt(x)");
deriv1("pnorm(x)").derive().eval(0).withHessian().derive(Ignored.OutputFormatting).eval(0);
assertDerivAndEval1(Ignored.OutputFormatting, "dnorm(x)");
assertDerivAndEval1("asin(x)");
assertDerivAndEval1(Ignored.OutputFormatting, "acos(x)");
deriv1("atan(x)").derive().eval(0).withHessian().derive(Ignored.OutputFormatting).eval(0);
assertDeriv1("gamma(x)").eval(Ignored.Unimplemented, 0);
assertDeriv1("lgamma(x)").eval(0.5);
assertDeriv1("digamma(x)").eval(Ignored.Unimplemented, 0);
assertDeriv1("trigamma(x)").eval(Ignored.Unimplemented, 0);
assertDeriv1("psigamma(x)").eval(Ignored.Unimplemented, 0);
}
@Test
public void testDeriveFunctionsWithCompArg1() {
deriv1("log(2*x)").derive().eval(0).withHessian().derive(Ignored.OutputFormatting).eval(0);
deriv1("log(sin(2*x))").derive().eval(0).withHessian().derive(Output.IgnoreWhitespace).eval(0);
assertDerivAndEval1(Output.IgnoreWhitespace, "log(sin(2*x)*cos(x^2))");
assertDerivAndEval1(Output.IgnoreWhitespace, "pnorm(sin(2*x)^log(x+1))");
}
@Test
public void testDeriveBasicExpressions2() {
assertDerivAndEval2("x + y");
deriv2("x*y").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("2*x*y").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("x/y/2").derive(Ignored.OutputFormatting).eval(0,
0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("2/x*y").derive(Ignored.OutputFormatting).eval(0,
0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("x^y").derive(Ignored.OutputFormatting).eval(0,
0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("(x+1)*(y+2)").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0,
0);
assertDerivAndEval2("(x+1)-(y+2)");
deriv2("-(x+1)+(y+2)").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0,
0);
deriv2("-(x+1)-(y+2)").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0,
0);
deriv2("(x+1)/(y+2)").derive(Ignored.OutputFormatting).eval(0,
0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("(x+1)*(y+2*(x-1))").derive().eval(0,
0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
deriv2("(x+1)^(y+2)").derive().eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0,
0).eval(1, 1);
}
@Test
public void testLongExpression() {
deriv2("(log(2*x)+sin(x))*cos(y^x*(exp(x)))*(x*y+x^y/(x+y+1))").derive(Output.IgnoreWhitespace).eval(0, 0).withHessian().derive(Ignored.OutputFormatting).eval(0, 0);
}
@Test
public void testFunctionGenereration() {
assertEval(Output.IgnoreWhitespace, "(df <- deriv(~x^2*sin(x), \"x\", function.arg=TRUE));df(0)");
assertEval(Output.IgnoreWhitespace, "(df <- deriv(~x^2*sin(x), \"x\", function.arg=c(\"x\")));df(0)");
assertEval(Output.IgnoreWhitespace, "(df <- deriv(~x^2*sin(x), \"x\", function.arg=function(x=1){}));df(0)");
}
@Test
public void testUnusualExprs() {
assertEval("(df <- deriv(expression(x^2*sin(x)), \"x\"));df(0)");
assertEval("(df <- deriv(quote(x^2*sin(x)), \"x\"));df(0)");
assertEval("g<-quote(x^2);(df <- deriv(g, \"x\"));df(0)");
assertEval("(df <- deriv(1, \"x\"));df(0)");
assertEval("x<-1;(df <- deriv(x, \"x\"));df(0)");
}
}
......@@ -764,3 +764,5 @@ com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/p
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/printer/ComplexVectorPrinter.java,gnu_r_gentleman_ihaka2.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/printer/LogicalVectorPrinter.java,gnu_r_gentleman_ihaka2.copyright
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/printer/PairListPrinter.java,gnu_r_gentleman_ihaka2.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/Deriv.java,gnu_r_gentleman_ihaka2.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/deriv/DerivVisitor.java,gnu_r_gentleman_ihaka2.copyright
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment