diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java deleted file mode 100644 index 130ec5a77e8d0e2344b0ff675af560e2908df99f..0000000000000000000000000000000000000000 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ -package com.oracle.truffle.r.nodes.builtin.base; - -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; -import static com.oracle.truffle.r.runtime.RDispatch.COMPLEX_GROUP_GENERIC; -import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; -import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; - -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RType; -import com.oracle.truffle.r.runtime.builtins.RBuiltin; - -@RBuiltin(name = "Arg", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) -public abstract class Arg extends UnaryArithmeticBuiltinNode { - - public Arg() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } - - static { - Casts casts = new Casts(Arg.class); - casts.arg("z").mustBe(numericValue().or(complexValue()), RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION); - } - - @Override - public RType calculateResultType(RType argumentType) { - switch (argumentType) { - case Complex: - return RType.Double; - default: - return super.calculateResultType(argumentType); - } - } - - @Override - public int op(byte op) { - return 0; - } - - @Override - public int op(int op) { - return op; - } - - @Override - public double op(double op) { - if (op >= 0) { - return 0; - } else { - return Math.PI; - } - } - - @Override - public double opd(double re, double im) { - return Math.atan2(im, re); - } - - @Specialization - @Override - public Object calculateUnboxed(Object op) { - return super.calculateUnboxed(op); - } -} diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 237a24e80ab60e7c692c0ca282405dedac39c250..5177cdd41d46415ce91e5425ec5ac21077b763aa 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -124,6 +124,8 @@ import com.oracle.truffle.r.nodes.builtin.fastr.FastRTry; import com.oracle.truffle.r.nodes.builtin.fastr.FastRTryNodeGen; import com.oracle.truffle.r.nodes.builtin.fastr.FastrDqrls; import com.oracle.truffle.r.nodes.builtin.fastr.FastrDqrlsNodeGen; +import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; +import com.oracle.truffle.r.nodes.unary.UnaryArithmeticSpecial; import com.oracle.truffle.r.nodes.unary.UnaryNotNode; import com.oracle.truffle.r.nodes.unary.UnaryNotNodeGen; import com.oracle.truffle.r.runtime.RVisibility; @@ -152,6 +154,38 @@ public class BasePackage extends RBuiltinPackage { */ add(UnaryNotNode.class, UnaryNotNodeGen::create); + addUnaryArithmetic(Ceiling.class, Ceiling::new); + addUnaryArithmetic(Floor.class, Floor::new); + addUnaryArithmetic(Trunc.class, Trunc::new); + addUnaryArithmetic(LogFunctions.Log10.class, LogFunctions.Log10::new); + addUnaryArithmetic(LogFunctions.Log1p.class, LogFunctions.Log1p::new); + addUnaryArithmetic(LogFunctions.Log2.class, LogFunctions.Log2::new); + addUnaryArithmetic(NumericalFunctions.Abs.class, NumericalFunctions.Abs::new); + addUnaryArithmetic(NumericalFunctions.Arg.class, NumericalFunctions.Arg::new); + addUnaryArithmetic(NumericalFunctions.Conj.class, NumericalFunctions.Conj::new); + addUnaryArithmetic(NumericalFunctions.Im.class, NumericalFunctions.Im::new); + addUnaryArithmetic(NumericalFunctions.Mod.class, NumericalFunctions.Mod::new); + addUnaryArithmetic(NumericalFunctions.Re.class, NumericalFunctions.Re::new); + addUnaryArithmetic(NumericalFunctions.Sign.class, NumericalFunctions.Sign::new); + addUnaryArithmetic(NumericalFunctions.Sqrt.class, NumericalFunctions.Sqrt::new); + addUnaryArithmetic(TrigExpFunctions.Acos.class, TrigExpFunctions.Acos::new); + addUnaryArithmetic(TrigExpFunctions.Acosh.class, TrigExpFunctions.Acosh::new); + addUnaryArithmetic(TrigExpFunctions.Asin.class, TrigExpFunctions.Asin::new); + addUnaryArithmetic(TrigExpFunctions.Asinh.class, TrigExpFunctions.Asinh::new); + addUnaryArithmetic(TrigExpFunctions.Atan.class, TrigExpFunctions.Atan::new); + addUnaryArithmetic(TrigExpFunctions.Atanh.class, TrigExpFunctions.Atanh::new); + addUnaryArithmetic(TrigExpFunctions.Cos.class, TrigExpFunctions.Cos::new); + addUnaryArithmetic(TrigExpFunctions.Cosh.class, TrigExpFunctions.Cosh::new); + addUnaryArithmetic(TrigExpFunctions.Cospi.class, TrigExpFunctions.Cospi::new); + addUnaryArithmetic(TrigExpFunctions.Exp.class, TrigExpFunctions.Exp::new); + addUnaryArithmetic(TrigExpFunctions.ExpM1.class, TrigExpFunctions.ExpM1::new); + addUnaryArithmetic(TrigExpFunctions.Sin.class, TrigExpFunctions.Sin::new); + addUnaryArithmetic(TrigExpFunctions.Sinh.class, TrigExpFunctions.Sinh::new); + addUnaryArithmetic(TrigExpFunctions.Sinpi.class, TrigExpFunctions.Sinpi::new); + addUnaryArithmetic(TrigExpFunctions.Tan.class, TrigExpFunctions.Tan::new); + addUnaryArithmetic(TrigExpFunctions.Tanh.class, TrigExpFunctions.Tanh::new); + addUnaryArithmetic(TrigExpFunctions.Tanpi.class, TrigExpFunctions.Tanpi::new); + addBinaryArithmetic(BinaryArithmetic.AddBuiltin.class, BinaryArithmetic.ADD, UnaryArithmetic.PLUS); addBinaryArithmetic(BinaryArithmetic.SubtractBuiltin.class, BinaryArithmetic.SUBTRACT, UnaryArithmetic.NEGATE); addBinaryArithmetic(BinaryArithmetic.DivBuiltin.class, BinaryArithmetic.DIV, null); @@ -176,12 +210,10 @@ public class BasePackage extends RBuiltinPackage { // Now load the rest of the builtins in "base" add(Abbrev.class, AbbrevNodeGen::create); add(APerm.class, APermNodeGen::create); - add(NumericalFunctions.Abs.class, NumericalFunctionsFactory.AbsNodeGen::create); add(All.class, AllNodeGen::create); add(AllNames.class, AllNamesNodeGen::create); add(Any.class, AnyNodeGen::create); add(AnyNA.class, AnyNANodeGen::create); - add(Arg.class, ArgNodeGen::create); add(Args.class, ArgsNodeGen::create); add(Array.class, ArrayNodeGen::create); add(AsCall.class, AsCallNodeGen::create); @@ -228,7 +260,6 @@ public class BasePackage extends RBuiltinPackage { add(CallAndExternalFunctions.DotExternalGraphics.class, CallAndExternalFunctionsFactory.DotExternalGraphicsNodeGen::create); add(Capabilities.class, CapabilitiesNodeGen::create); add(Cat.class, CatNodeGen::create); - add(Ceiling.class, CeilingNodeGen::create); add(CharMatch.class, CharMatchNodeGen::create); add(Col.class, ColNodeGen::create); add(Colon.class, ColonNodeGen::create, Colon::special); @@ -239,7 +270,6 @@ public class BasePackage extends RBuiltinPackage { add(Complex.class, ComplexNodeGen::create); add(CompileFunctions.CompilePKGS.class, CompileFunctionsFactory.CompilePKGSNodeGen::create); add(CompileFunctions.EnableJIT.class, CompileFunctionsFactory.EnableJITNodeGen::create); - add(NumericalFunctions.Conj.class, NumericalFunctionsFactory.ConjNodeGen::create); add(ConditionFunctions.AddCondHands.class, ConditionFunctionsFactory.AddCondHandsNodeGen::create); add(ConditionFunctions.AddRestart.class, ConditionFunctionsFactory.AddRestartNodeGen::create); add(ConditionFunctions.DfltStop.class, ConditionFunctionsFactory.DfltStopNodeGen::create); @@ -421,7 +451,6 @@ public class BasePackage extends RBuiltinPackage { add(FileFunctions.ListFiles.class, FileFunctionsFactory.ListFilesNodeGen::create); add(FileFunctions.ListDirs.class, FileFunctionsFactory.ListDirsNodeGen::create); add(FileFunctions.Unlink.class, FileFunctionsFactory.UnlinkNodeGen::create); - add(Floor.class, FloorNodeGen::create); add(ForceAndCall.class, ForceAndCallNodeGen::create); add(Formals.class, FormalsNodeGen::create); add(Format.class, FormatNodeGen::create); @@ -466,7 +495,6 @@ public class BasePackage extends RBuiltinPackage { add(HiddenInternalFunctions.LazyLoadDBinsertValue.class, HiddenInternalFunctionsFactory.LazyLoadDBinsertValueNodeGen::create); add(IConv.class, IConvNodeGen::create); add(Identical.class, Identical::create); - add(NumericalFunctions.Im.class, NumericalFunctionsFactory.ImNodeGen::create); add(InheritsBuiltin.class, InheritsBuiltinNodeGen::create); add(Interactive.class, InteractiveNodeGen::create); add(Internal.class, InternalNodeGen::create); @@ -529,9 +557,6 @@ public class BasePackage extends RBuiltinPackage { add(LocaleFunctions.LocaleConv.class, LocaleFunctionsFactory.LocaleConvNodeGen::create); add(LocaleFunctions.SetLocale.class, LocaleFunctionsFactory.SetLocaleNodeGen::create); add(LogFunctions.Log.class, LogFunctionsFactory.LogNodeGen::create); - add(LogFunctions.Log10.class, LogFunctionsFactory.Log10NodeGen::create); - add(LogFunctions.Log1p.class, LogFunctionsFactory.Log1pNodeGen::create); - add(LogFunctions.Log2.class, LogFunctionsFactory.Log2NodeGen::create); add(Ls.class, LsNodeGen::create); add(MakeNames.class, MakeNamesNodeGen::create); add(MakeUnique.class, MakeUniqueNodeGen::create); @@ -546,7 +571,6 @@ public class BasePackage extends RBuiltinPackage { add(Merge.class, MergeNodeGen::create); add(Min.class, MinNodeGen::create); add(Missing.class, MissingNodeGen::create); - add(NumericalFunctions.Mod.class, NumericalFunctionsFactory.ModNodeGen::create); add(NArgs.class, NArgsNodeGen::create); add(NChar.class, NCharNodeGen::create); add(NGetText.class, NGetTextNodeGen::create); @@ -586,7 +610,6 @@ public class BasePackage extends RBuiltinPackage { add(RawFunctions.RawToChar.class, RawFunctionsFactory.RawToCharNodeGen::create); add(RawFunctions.RawShift.class, RawFunctionsFactory.RawShiftNodeGen::create); add(RawToBits.class, RawToBitsNodeGen::create); - add(NumericalFunctions.Re.class, NumericalFunctionsFactory.ReNodeGen::create); add(ReadDCF.class, ReadDCFNodeGen::create); add(ReadREnviron.class, ReadREnvironNodeGen::create); add(Readline.class, ReadlineNodeGen::create); @@ -618,7 +641,6 @@ public class BasePackage extends RBuiltinPackage { add(SerializeFunctions.UnserializeFromConn.class, SerializeFunctionsFactory.UnserializeFromConnNodeGen::create); add(Setwd.class, SetwdNodeGen::create); add(ShortRowNames.class, ShortRowNamesNodeGen::create); - add(NumericalFunctions.Sign.class, NumericalFunctionsFactory.SignNodeGen::create); add(Signif.class, SignifNodeGen::create); add(SinkFunctions.Sink.class, SinkFunctionsFactory.SinkNodeGen::create); add(SinkFunctions.SinkNumber.class, SinkFunctionsFactory.SinkNumberNodeGen::create); @@ -629,7 +651,6 @@ public class BasePackage extends RBuiltinPackage { add(SortFunctions.Sort.class, SortFunctionsFactory.SortNodeGen::create); add(Split.class, SplitNodeGen::create); add(Sprintf.class, SprintfNodeGen::create); - add(NumericalFunctions.Sqrt.class, NumericalFunctionsFactory.SqrtNodeGen::create); add(StandardGeneric.class, StandardGenericNodeGen::create); add(StartsEndsWithFunctions.StartsWith.class, StartsEndsWithFunctionsFactory.StartsWithNodeGen::create); add(StartsEndsWithFunctions.EndsWith.class, StartsEndsWithFunctionsFactory.EndsWithNodeGen::create); @@ -667,25 +688,7 @@ public class BasePackage extends RBuiltinPackage { add(TraceFunctions.Retracemem.class, TraceFunctionsFactory.RetracememNodeGen::create); add(TraceFunctions.Untracemem.class, TraceFunctionsFactory.UntracememNodeGen::create); add(Transpose.class, TransposeNodeGen::create); - add(TrigExpFunctions.Acos.class, TrigExpFunctionsFactory.AcosNodeGen::create); - add(TrigExpFunctions.Acosh.class, TrigExpFunctionsFactory.AcoshNodeGen::create); - add(TrigExpFunctions.Asin.class, TrigExpFunctionsFactory.AsinNodeGen::create); - add(TrigExpFunctions.Asinh.class, TrigExpFunctionsFactory.AsinhNodeGen::create); - add(TrigExpFunctions.Atan.class, TrigExpFunctionsFactory.AtanNodeGen::create); add(TrigExpFunctions.Atan2.class, TrigExpFunctionsFactory.Atan2NodeGen::create); - add(TrigExpFunctions.Atanh.class, TrigExpFunctionsFactory.AtanhNodeGen::create); - add(TrigExpFunctions.Cos.class, TrigExpFunctionsFactory.CosNodeGen::create); - add(TrigExpFunctions.Cosh.class, TrigExpFunctionsFactory.CoshNodeGen::create); - add(TrigExpFunctions.Cospi.class, TrigExpFunctionsFactory.CospiNodeGen::create); - add(TrigExpFunctions.Exp.class, TrigExpFunctionsFactory.ExpNodeGen::create); - add(TrigExpFunctions.ExpM1.class, TrigExpFunctionsFactory.ExpM1NodeGen::create); - add(TrigExpFunctions.Sin.class, TrigExpFunctionsFactory.SinNodeGen::create); - add(TrigExpFunctions.Sinh.class, TrigExpFunctionsFactory.SinhNodeGen::create); - add(TrigExpFunctions.Sinpi.class, TrigExpFunctionsFactory.SinpiNodeGen::create); - add(TrigExpFunctions.Tan.class, TrigExpFunctionsFactory.TanNodeGen::create); - add(TrigExpFunctions.Tanh.class, TrigExpFunctionsFactory.TanhNodeGen::create); - add(TrigExpFunctions.Tanpi.class, TrigExpFunctionsFactory.TanpiNodeGen::create); - add(Trunc.class, TruncNodeGen::create); add(Typeof.class, TypeofNodeGen::create); add(UnClass.class, UnClassNodeGen::create); add(Unique.class, UniqueNodeGen::create); @@ -745,7 +748,11 @@ public class BasePackage extends RBuiltinPackage { } private void addBinaryArithmetic(Class<?> builtinClass, BinaryArithmeticFactory binaryFactory, UnaryArithmeticFactory unaryFactory) { - add(builtinClass, () -> BinaryArithmeticNodeGen.create(binaryFactory, unaryFactory), BinaryArithmeticSpecial.createSpecialFactory(binaryFactory)); + add(builtinClass, () -> BinaryArithmeticNodeGen.create(binaryFactory, unaryFactory), BinaryArithmeticSpecial.createSpecialFactory(binaryFactory, unaryFactory)); + } + + private void addUnaryArithmetic(Class<?> builtinClass, UnaryArithmeticFactory unaryFactory) { + add(builtinClass, () -> new UnaryArithmeticBuiltinNode(unaryFactory), UnaryArithmeticSpecial.createSpecialFactory(unaryFactory)); } private void addBinaryCompare(Class<?> builtinClass, BooleanOperationFactory factory) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Ceiling.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Ceiling.java index 47d7419cb8ab1a5b233dba9615c0a07ac2ececf9..4590c44a69bafa50113821feb55ec36a55510f2f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Ceiling.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Ceiling.java @@ -22,64 +22,21 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_ARITHMETIC; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; @RBuiltin(name = "ceiling", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) -public abstract class Ceiling extends UnaryArithmeticBuiltinNode { +public final class Ceiling extends UnaryArithmetic { - public static final UnaryArithmeticFactory CEILING = FloorNodeGen.create(); - - public Ceiling() { - super(RType.Double, RError.Message.NON_NUMERIC_MATH, null); - } - - static { - Casts casts = new Casts(Ceiling.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustNotBeNull().mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue()).asDoubleVector(true, - true, true); - } - - @Override - public int op(byte op) { - return op; - } - - @Override - public int op(int op) { - return op; - } + public static final UnaryArithmeticFactory CEILING = Ceiling::new; @Override public double op(double op) { return Math.ceil(op); } - - @Override - protected double opd(double re, double im) { - return op(re); - } - - @Override - public RComplex op(double re, double im) { - return RDataFactory.createComplex(op(re), op(im)); - } - - @Specialization - @Override - public Object calculateUnboxed(Object op) { - return super.calculateUnboxed(op); - } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Floor.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Floor.java index 180e8af47caf92acec9b4c2206a8268cd9958820..9f531ce897be5d47cfc479fbf748f1b2cbde200e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Floor.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Floor.java @@ -22,64 +22,21 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_ARITHMETIC; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; @RBuiltin(name = "floor", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) -public abstract class Floor extends UnaryArithmeticBuiltinNode { +public final class Floor extends UnaryArithmetic { - public static final UnaryArithmeticFactory FLOOR = FloorNodeGen.create(); - - public Floor() { - super(RType.Double, RError.Message.NON_NUMERIC_MATH, null); - } - - static { - Casts casts = new Casts(Floor.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustNotBeNull().mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue()).asDoubleVector(true, - true, true); - } - - @Override - public int op(byte op) { - return op; - } - - @Override - public int op(int op) { - return op; - } + public static final UnaryArithmeticFactory FLOOR = Floor::new; @Override public double op(double op) { return Math.floor(op); } - - @Override - protected double opd(double re, double im) { - return op(re); - } - - @Override - public RComplex op(double re, double im) { - return RDataFactory.createComplex(op(re), op(im)); - } - - @Specialization - @Override - public Object calculateUnboxed(Object op) { - return super.calculateUnboxed(op); - } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java index 4a3a38d7134f8499c65c3460719efd9cdac0439e..09bde8d617b02030f4f908522f51447acb1d3173 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LogFunctions.java @@ -42,7 +42,6 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNames import com.oracle.truffle.r.nodes.binary.BinaryMapArithmeticFunctionNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; @@ -57,6 +56,7 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.na.NACheck; import com.oracle.truffle.r.runtime.ops.na.NAProfile; @@ -348,19 +348,10 @@ public class LogFunctions { } @RBuiltin(name = "log10", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Log10 extends UnaryArithmeticBuiltinNode { - - public Log10() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Log10 extends UnaryArithmetic { private static final double LOG_10 = Math.log(10); - static { - Casts casts = new Casts(Log10.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); - } - @Override public double op(double op) { return Math.log10(op); @@ -375,19 +366,10 @@ public class LogFunctions { } @RBuiltin(name = "log2", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Log2 extends UnaryArithmeticBuiltinNode { - - public Log2() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Log2 extends UnaryArithmetic { private static final double LOG_2 = Math.log(2); - static { - Casts casts = new Casts(Log2.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); - } - @Override public double op(double op) { return Math.log(op) / LOG_2; @@ -402,30 +384,11 @@ public class LogFunctions { } @RBuiltin(name = "log1p", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Log1p extends UnaryArithmeticBuiltinNode { - - public Log1p() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } - - static { - Casts casts = new Casts(Log1p.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); - } - - @Override - public int op(byte op) { - throw new UnsupportedOperationException(); - } - - @Override - public int op(int op) { - throw new UnsupportedOperationException(); - } + public static final class Log1p extends UnaryArithmetic { @Override public double op(double op) { - return Math.log(1 + op); + return Math.log1p(op); } @Override diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NumericalFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NumericalFunctions.java index c540de96876e1e0f56e5c95b70328fa1988ebc5d..9bcd735575be268d51ba6daf95ed6597db6f8e5e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NumericalFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NumericalFunctions.java @@ -22,58 +22,32 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.runtime.RDispatch.COMPLEX_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RComplex; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; public class NumericalFunctions { - /** - * This node is only a workaround that makes the annotation processor process the other inner - * node classes. These classes would be ignored otherwise, since they do not contain any - * specialization, which would trigger the code generation performed by the annotation - * processor. - */ - public abstract static class DummyNode extends RBuiltinNode.Arg1 { - - @Specialization - protected Object dummySpec(@SuppressWarnings("unused") Object value) { - return null; - } - } - @RBuiltin(name = "abs", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Abs extends UnaryArithmeticBuiltinNode { + public static final class Abs extends UnaryArithmetic { - public Abs() { - super(RType.Integer, RError.Message.NON_NUMERIC_MATH, null); - } - - static { - Casts casts = new Casts(Abs.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(numericValue().or(complexValue())); + @Override + public RType calculateResultType(RType argumentType) { + return argumentType == RType.Complex ? RType.Double : argumentType; } @Override - public RType calculateResultType(RType argumentType) { - switch (argumentType) { - case Complex: - return RType.Double; - default: - return super.calculateResultType(argumentType); - } + public RType getMinPrecedence() { + return RType.Integer; } @Override @@ -98,25 +72,16 @@ public class NumericalFunctions { } @RBuiltin(name = "Re", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) - public abstract static class Re extends UnaryArithmeticBuiltinNode { - - public Re() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Re extends UnaryArithmetic { - static { - Casts casts = new Casts(Re.class); - casts.arg("z").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); + @Override + public RType calculateResultType(RType argumentType) { + return argumentType == RType.Complex ? RType.Double : argumentType; } @Override - public RType calculateResultType(RType argumentType) { - switch (argumentType) { - case Complex: - return RType.Double; - default: - return super.calculateResultType(argumentType); - } + public Message getArgumentError() { + return RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION; } @Override @@ -138,28 +103,20 @@ public class NumericalFunctions { public double opd(double re, double im) { return re; } + } @RBuiltin(name = "Im", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) - public abstract static class Im extends UnaryArithmeticBuiltinNode { - - public Im() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Im extends UnaryArithmetic { - static { - Casts casts = new Casts(Im.class); - casts.arg("z").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); + @Override + public RType calculateResultType(RType argumentType) { + return argumentType == RType.Complex ? RType.Double : argumentType; } @Override - public RType calculateResultType(RType argumentType) { - switch (argumentType) { - case Complex: - return RType.Double; - default: - return super.calculateResultType(argumentType); - } + public Message getArgumentError() { + return RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION; } @Override @@ -184,15 +141,11 @@ public class NumericalFunctions { } @RBuiltin(name = "Conj", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) - public abstract static class Conj extends UnaryArithmeticBuiltinNode { - - public Conj() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Conj extends UnaryArithmetic { - static { - Casts casts = new Casts(Conj.class); - casts.arg("z").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); + @Override + public Message getArgumentError() { + return RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION; } @Override @@ -217,25 +170,16 @@ public class NumericalFunctions { } @RBuiltin(name = "Mod", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) - public abstract static class Mod extends UnaryArithmeticBuiltinNode { - - public Mod() { - super(RType.Double, RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION, null); - } + public static final class Mod extends UnaryArithmetic { - static { - Casts casts = new Casts(Mod.class); - casts.arg("z").defaultError(RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION).mustBe(numericValue().or(complexValue())); + @Override + public RType calculateResultType(RType argumentType) { + return argumentType == RType.Complex ? RType.Double : argumentType; } @Override - public RType calculateResultType(RType argumentType) { - switch (argumentType) { - case Complex: - return RType.Double; - default: - return super.calculateResultType(argumentType); - } + public Message getArgumentError() { + return RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION; } @Override @@ -259,18 +203,42 @@ public class NumericalFunctions { } } - @RBuiltin(name = "sign", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Sign extends UnaryArithmeticBuiltinNode { + @RBuiltin(name = "Arg", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) + public static final class Arg extends UnaryArithmetic { + + @Override + public RType calculateResultType(RType argumentType) { + return argumentType == RType.Complex ? RType.Double : argumentType; + } + + @Override + public int op(byte op) { + return 0; + } - public Sign() { - super(RType.Double, RError.Message.NON_NUMERIC_MATH, null); + @Override + public int op(int op) { + return op; } - static { - Casts casts = new Casts(Sign.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue()); + @Override + public double op(double op) { + if (op >= 0) { + return 0; + } else { + return Math.PI; + } } + @Override + public double opd(double re, double im) { + return Math.atan2(im, re); + } + } + + @RBuiltin(name = "sign", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) + public static final class Sign extends UnaryArithmetic { + @Override public int op(byte op) { return op == RRuntime.LOGICAL_TRUE ? 1 : 0; @@ -285,19 +253,15 @@ public class NumericalFunctions { public double op(double op) { return Math.signum(op); } - } - @RBuiltin(name = "sqrt", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) - public abstract static class Sqrt extends UnaryArithmeticBuiltinNode { - - public Sqrt() { - super(RType.Double, RError.Message.NON_NUMERIC_MATH, null); + @Override + public RComplex op(double re, double im) { + throw error(Message.UNIMPLEMENTED_COMPLEX_FUN); } + } - static { - Casts casts = new Casts(Sqrt.class); - casts.arg("x").defaultError(RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue().or(complexValue())); - } + @RBuiltin(name = "sqrt", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) + public static final class Sqrt extends UnaryArithmetic { @Override public int op(byte op) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java index 832df443259d101e311724243bbb40c2558ecd9c..ec058295772e1fd4df8aa0c514bcf36fe3c7086f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Round.java @@ -30,6 +30,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; @@ -44,15 +45,11 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; -import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "round", kind = PRIMITIVE, parameterNames = {"x", "digits"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) public abstract class Round extends RBuiltinNode.Arg2 { - public static final UnaryArithmeticFactory ROUND = RoundArithmetic::new; - @Child private RoundArithmetic roundOp = new RoundArithmetic(); private final NACheck check = NACheck.create(); @@ -146,7 +143,7 @@ public abstract class Round extends RBuiltinNode.Arg2 { @Specialization(guards = "digits == 0") protected RComplex round(RComplex x, @SuppressWarnings("unused") int digits) { check.enable(x); - return check.check(x) ? RComplex.createNA() : roundOp.op(x.getRealPart(), x.getImaginaryPart()); + return check.check(x) ? RComplex.createNA() : RComplex.valueOf(roundOp.op(x.getRealPart()), roundOp.op(x.getImaginaryPart())); } @Specialization(guards = "digits != 0") @@ -187,30 +184,25 @@ public abstract class Round extends RBuiltinNode.Arg2 { return ret; } - public static class RoundArithmetic extends UnaryArithmetic { + public static final class RoundArithmetic extends Node { @Child private BinaryArithmetic pow; - @Override + @SuppressWarnings("static-method") public int op(int op) { return op; } - @Override + @SuppressWarnings("static-method") public double op(double op) { return Math.rint(op); } - @Override + @SuppressWarnings("static-method") public int op(byte op) { return op; } - @Override - public RComplex op(double re, double im) { - return RDataFactory.createComplex(op(re), op(im)); - } - public double opd(double op, int digits) { return fround(op, digits); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TrigExpFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TrigExpFunctions.java index ccf9d75577c560dc9b136ea61c1b000526d7705e..8ba21984c778931efe59e9c2e003f5dd55a3c4eb 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TrigExpFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/TrigExpFunctions.java @@ -36,17 +36,10 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.AcosNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.AsinNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.AtanNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.CosNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.SinNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.TrigExpFunctionsFactory.TanNodeGen; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticBuiltinNode; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; -import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -55,17 +48,13 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic.Pow.CHypot; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.na.NACheck; public class TrigExpFunctions { @RBuiltin(name = "exp", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Exp extends UnaryArithmeticBuiltinNode { - - public Exp() { - super(RType.Double); - } - + public static final class Exp extends UnaryArithmetic { @Child private BinaryArithmetic calculatePowNode; @Override @@ -84,11 +73,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "expm1", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class ExpM1 extends UnaryArithmeticBuiltinNode { - - public ExpM1() { - super(RType.Double); - } + public static final class ExpM1 extends UnaryArithmetic { @Child private BinaryArithmetic calculatePowNode; @@ -109,11 +94,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "sin", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Sin extends UnaryArithmeticBuiltinNode { - - public Sin() { - super(RType.Double); - } + public static final class Sin extends UnaryArithmetic { @Override public double op(double op) { @@ -129,11 +110,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "sinh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Sinh extends UnaryArithmeticBuiltinNode { - - public Sinh() { - super(RType.Double); - } + public static final class Sinh extends UnaryArithmetic { @Override public double op(double op) { @@ -149,12 +126,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "sinpi", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Sinpi extends UnaryArithmeticBuiltinNode { - - public Sinpi() { - super(RType.Double); - } - + public static final class Sinpi extends UnaryArithmetic { @Override public double op(double op) { double norm = op % 2d; @@ -169,15 +141,15 @@ public class TrigExpFunctions { } return Math.sin(norm * Math.PI); } - } - @RBuiltin(name = "cos", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Cos extends UnaryArithmeticBuiltinNode { - - public Cos() { - super(RType.Double); + @Override + public RComplex op(double re, double im) { + throw error(Message.UNIMPLEMENTED_COMPLEX_FUN); } + } + @RBuiltin(name = "cos", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) + public static final class Cos extends UnaryArithmetic { @Override public double op(double op) { return Math.cos(op); @@ -192,12 +164,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "cosh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Cosh extends UnaryArithmeticBuiltinNode { - - public Cosh() { - super(RType.Double); - } - + public static final class Cosh extends UnaryArithmetic { @Override public double op(double op) { return Math.cosh(op); @@ -212,12 +179,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "cospi", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Cospi extends UnaryArithmeticBuiltinNode { - - public Cospi() { - super(RType.Double); - } - + public static final class Cospi extends UnaryArithmetic { @Override public double op(double op) { double norm = op % 2d; @@ -240,14 +202,9 @@ public class TrigExpFunctions { } @RBuiltin(name = "tan", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Tan extends UnaryArithmeticBuiltinNode { - - public Tan() { - super(RType.Double); - } - - @Child private Sin sinNode = SinNodeGen.create(); - @Child private Cos cosNode = CosNodeGen.create(); + public static final class Tan extends UnaryArithmetic { + @Child private Sin sinNode = new Sin(); + @Child private Cos cosNode = new Cos(); @Override public double op(double op) { @@ -266,13 +223,8 @@ public class TrigExpFunctions { } @RBuiltin(name = "tanh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Tanh extends UnaryArithmeticBuiltinNode { - - public Tanh() { - super(RType.Double); - } - - @Child private Tan tanNode = TanNodeGen.create(); + public static final class Tanh extends UnaryArithmetic { + @Child private Tan tanNode = new Tan(); @Override public double op(double op) { @@ -287,12 +239,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "tanpi", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Tanpi extends UnaryArithmeticBuiltinNode { - - public Tanpi() { - super(RType.Double); - } - + public static final class Tanpi extends UnaryArithmetic { @Override public double op(double op) { double norm = op % 1d; @@ -312,14 +259,10 @@ public class TrigExpFunctions { } @RBuiltin(name = "asin", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Asin extends UnaryArithmeticBuiltinNode { + public static final class Asin extends UnaryArithmetic { @Child private CHypot chypot; - public Asin() { - super(RType.Double); - } - private void ensureChypot() { if (chypot == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); @@ -371,13 +314,8 @@ public class TrigExpFunctions { } @RBuiltin(name = "asinh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Asinh extends UnaryArithmeticBuiltinNode { - - @Child private Asin asinNode = AsinNodeGen.create(); - - public Asinh() { - super(RType.Double); - } + public static final class Asinh extends UnaryArithmetic { + @Child private Asin asinNode = new Asin(); @Override public double op(double x) { @@ -392,13 +330,8 @@ public class TrigExpFunctions { } @RBuiltin(name = "acos", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Acos extends UnaryArithmeticBuiltinNode { - - public Acos() { - super(RType.Double); - } - - @Child private Asin asinNode = AsinNodeGen.create(); + public static final class Acos extends UnaryArithmetic { + @Child private Asin asinNode = new Asin(); @Override public double op(double op) { @@ -413,13 +346,9 @@ public class TrigExpFunctions { } @RBuiltin(name = "acosh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Acosh extends UnaryArithmeticBuiltinNode { - - public Acosh() { - super(RType.Double); - } + public static final class Acosh extends UnaryArithmetic { - @Child private Acos acosNode = AcosNodeGen.create(); + @Child private Acos acosNode = new Acos(); @Override public double op(double x) { @@ -434,12 +363,7 @@ public class TrigExpFunctions { } @RBuiltin(name = "atan", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Atan extends UnaryArithmeticBuiltinNode { - - public Atan() { - super(RType.Double); - } - + public static final class Atan extends UnaryArithmetic { @Override public double op(double x) { return Math.atan(x); @@ -464,13 +388,9 @@ public class TrigExpFunctions { } @RBuiltin(name = "atanh", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) - public abstract static class Atanh extends UnaryArithmeticBuiltinNode { - - public Atanh() { - super(RType.Double); - } + public static final class Atanh extends UnaryArithmetic { - @Child private Atan atanNode = AtanNodeGen.create(); + @Child private Atan atanNode = new Atan(); @Override public double op(double x) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Trunc.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Trunc.java index 38cd8445a449e3c64efebd188827d577d9f505a6..7cc77fc5a68beca55a1a37d3001f7b3b99cfe2f8 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Trunc.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Trunc.java @@ -22,51 +22,25 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_ARITHMETIC; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; -import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNodeGen; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNode; -import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNodeGen; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; @RBuiltin(name = "trunc", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE_ARITHMETIC) -public abstract class Trunc extends RBuiltinNode.Arg1 { +public final class Trunc extends UnaryArithmetic { - private static final UnaryArithmeticFactory TRUNC = TruncArithmetic::new; + public static final UnaryArithmeticFactory TRUNC = Trunc::new; - @Child private BoxPrimitiveNode boxPrimitive = BoxPrimitiveNodeGen.create(); - @Child private UnaryArithmeticNode trunc = UnaryArithmeticNodeGen.create(TRUNC, RError.Message.NON_NUMERIC_MATH, RType.Double); - - static { - Casts casts = new Casts(Trunc.class); - casts.arg("x").defaultError(RError.Message.NON_NUMERIC_MATH).mustNotBeNull().mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue()).asDoubleVector(true, - true, true); - } - - @Specialization - protected Object trunc(Object value) { - return trunc.execute(boxPrimitive.execute(value)); - } - - private static final class TruncArithmetic extends Round.RoundArithmetic { - - @Override - public double op(double op) { - if (op > 0) { - return Math.floor(op); - } else { - return Math.ceil(op); - } + @Override + public double op(double op) { + if (op > 0) { + return Math.floor(op); + } else { + return Math.ceil(op); } } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java index f1c99a37f020f0390bb88803583aa6f11cf51e06..807c73493255f998ad075b42347efd0e83685640 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java @@ -55,7 +55,7 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { @Child private ExtractListElement extractListElement = ExtractListElement.create(); - @Specialization(limit = "2", guards = {"isSimpleList(list)", "list.getNames() == cachedNames", "field == cachedField"}) + @Specialization(limit = "2", guards = {"getNamesNode.getNames(list) == cachedNames", "field == cachedField"}) public Object doList(RList list, @SuppressWarnings("unused") String field, @SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames, @SuppressWarnings("unused") @Cached("field") String cachedField, @@ -66,7 +66,7 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { return extractListElement.execute(list, index); } - @Specialization(replaces = "doList", guards = {"isSimpleList(list)", "list.getNames() != null"}) + @Specialization(replaces = "doList") public Object doListDynamic(RList list, String field) { int index = getIndex(getNamesNode.getNames(list), field); if (index == -1) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java index 569af73be2fc13c8a9178766d0a4c773b7b30f4f..987dc95344f9af694fe34fb37f26f66050602580 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java @@ -35,7 +35,6 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNames import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertIndexNodeGen; import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertValueNodeGen; import com.oracle.truffle.r.nodes.function.ClassHierarchyNode; -import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RIntVector; @@ -68,16 +67,14 @@ class SpecialsUtils { */ abstract static class SubscriptSpecialCommon extends Node { - @Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false); - protected final boolean inReplacement; protected SubscriptSpecialCommon(boolean inReplacement) { this.inReplacement = inReplacement; } - protected boolean simpleVector(RAbstractVector vector) { - return classHierarchy.execute(vector) == null; + protected boolean simpleVector(@SuppressWarnings("unused") RAbstractVector vector) { + return true; } /** @@ -123,11 +120,10 @@ class SpecialsUtils { */ abstract static class ListFieldSpecialBase extends RNode { - @Child private ClassHierarchyNode hierarchyNode = ClassHierarchyNode.create(); @Child protected GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); protected final boolean isSimpleList(RList list) { - return hierarchyNode.execute(list) == null; + return true; } protected static int getIndex(RStringVector names, String field) { @@ -160,6 +156,8 @@ class SpecialsUtils { protected abstract RNode getDelegate(); + public abstract Object execute(Object value); + @Specialization protected static int convertInteger(int value) { return value; @@ -168,7 +166,7 @@ class SpecialsUtils { @Specialization(rewriteOn = IllegalArgumentException.class) protected int convertDouble(double value) { int intValue = (int) value; - if (intValue == 0) { + if (intValue <= 0) { /* * Conversion from double to an index differs in subscript and subset for values in * the ]0..1[ range (subscript interprets 0.1 as 1, whereas subset treats it as 0). diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java index da94856c0023cec421e7da1ac26726604dc07225..8662349032804dab12676c37d925735373219b12 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java @@ -68,7 +68,9 @@ abstract class SubscriptSpecialBase extends SubscriptSpecialCommon { super(inReplacement); } - protected abstract Object execute(VirtualFrame frame, Object vec, Object index); + public abstract Object execute(VirtualFrame frame, Object vec, Object index); + + public abstract Object execute(VirtualFrame frame, Object vec, int index); @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) protected int access(RAbstractIntVector vector, int index) { @@ -103,6 +105,8 @@ abstract class SubscriptSpecial2Base extends SubscriptSpecial2Common { public abstract Object execute(VirtualFrame frame, Object vector, Object index1, Object index2); + public abstract Object execute(VirtualFrame frame, Object vec, int index1, int index2); + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"}) protected int access(RAbstractIntVector vector, int index1, int index2) { return vector.getDataAt(matrixIndex(vector, index1, index2)); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java index 3d1e5f25835fc4646394354aa6611cd29187b945..f1145838fa630caadf6b3400b8e3aa8c0023b86c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java @@ -66,7 +66,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { return value != RNull.instance && !(value instanceof RList); } - @Specialization(limit = "2", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() == cachedNames", "field == cachedField", "isNotRNullRList(value)"}) + @Specialization(limit = "2", guards = {"!list.isShared()", "getNamesNode.getNames(list) == cachedNames", "field == cachedField", "isNotRNullRList(value)"}) public Object doList(RList list, @SuppressWarnings("unused") String field, Object value, @SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames, @SuppressWarnings("unused") @Cached("field") String cachedField, @@ -83,7 +83,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { return list; } - @Specialization(replaces = "doList", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() != null", "isNotRNullRList(value)"}) + @Specialization(replaces = "doList", guards = {"!list.isShared()", "list.getNames() != null", "isNotRNullRList(value)"}) public RList doListDynamic(RList list, String field, Object value) { int index = getIndex(getNamesNode.getNames(list), field); if (index == -1) { diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java index ae709c5c8b51eb5655f9229e11cd69f35f90eb9b..74d69baca5ae8a6dff88d506f7443032cb6baebc 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java @@ -186,8 +186,8 @@ public class SpecialCallTest extends TestBase { assertCallCounts("a <- c(1,2,3,4)", "a[0]", 1, 0, 1, 0); assertCallCounts("a <- c(1,2,3,4); b <- -1", "a[b]", 1, 0, 1, 0); assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[-1]", 2, 0, 2, 0); - assertCallCounts("a <- c(1,2,3,4)", "a[-1]", 0, 2, 0, 2); // "-1" is a unary expression assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1]", 0, 1, 0, 1); assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1]", 0, 1, 0, 1); assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F]", 0, 1, 0, 1); @@ -223,7 +223,7 @@ public class SpecialCallTest extends TestBase { assertCallCounts("a <- c(1,2,3,4); b <- -1", "a[b] <- 1", 1, 0, 1, 1); assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_] <- 1", 1, 0, 1, 1); - assertCallCounts("a <- c(1,2,3,4)", "a[-1] <- 1", 0, 2, 0, 3); // "-1" is a unary expression + assertCallCounts("a <- c(1,2,3,4)", "a[-1] <- 1", 2, 0, 2, 1); assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1] <- 1", 0, 1, 0, 2); assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1] <- 1", 0, 1, 0, 2); assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F] <- 1", 0, 1, 0, 2); @@ -254,7 +254,7 @@ public class SpecialCallTest extends TestBase { assertCallCounts("a <- 1", "('asdf')", 1, 0, 1, 0); assertCallCounts("a <- 1; b <- 2", "(a + b)", 2, 0, 2, 0); assertCallCounts("a <- 1; b <- 2; c <- 3", "a + (b + c)", 3, 0, 3, 0); - assertCallCounts("a <- 1; b <- 2; c <- 1:5", "a + (b + c)", 3, 0, 0, 3); + assertCallCounts("a <- 1; b <- 2; c <- 1:5", "a + (b + c)", 3, 0, 3, 0); } private static void assertCallCounts(String test, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java index c575efa26a681a498bc5253cb5fe78ff1281aacc..ffee9580cf61c3e3f9b117dc0bc13fe7d8802efa 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/UnaryArithmeticNodeTest.java @@ -52,7 +52,7 @@ import org.junit.runner.RunWith; import com.oracle.truffle.api.object.DynamicObject; import com.oracle.truffle.r.nodes.builtin.base.Ceiling; import com.oracle.truffle.r.nodes.builtin.base.Floor; -import com.oracle.truffle.r.nodes.builtin.base.Round; +import com.oracle.truffle.r.nodes.builtin.base.Trunc; import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNode; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticNodeGen; @@ -64,6 +64,7 @@ import com.oracle.truffle.r.runtime.data.RSequence; import com.oracle.truffle.r.runtime.data.RShareable; import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; /** @@ -77,7 +78,7 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { // to make sure this file is recognized as a test } - public static final UnaryArithmeticFactory[] ALL = new UnaryArithmeticFactory[]{NEGATE, Round.ROUND, Floor.FLOOR, Ceiling.CEILING, PLUS}; + public static final UnaryArithmeticFactory[] ALL = new UnaryArithmeticFactory[]{NEGATE, PLUS, Floor.FLOOR, Ceiling.CEILING, Trunc.TRUNC}; @DataPoints public static final UnaryArithmeticFactory[] UNARY = ALL; @@ -100,7 +101,7 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { // sharing does not work if a is a scalar vector assumeThat(true, is(isShareable(operand, operand.getRType()))); - RType resultType = getArgumentType(operand); + RType resultType = getArgumentType(factory, operand); Object sharedResult = null; if (isShareable(operand, resultType)) { sharedResult = operand; @@ -161,7 +162,7 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { public void testPlusFolding(RAbstractVector originalOperand) { RAbstractVector operand = copy(originalOperand); assumeThat(operand, is(not(instanceOf(RScalarVector.class)))); - if (operand.getRType() == getArgumentType(operand)) { + if (operand.getRType() == getArgumentType(PLUS, operand)) { assertFold(true, operand, PLUS); } else { assertFold(false, operand, PLUS); @@ -172,8 +173,8 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { public void testSequenceFolding() { assertFold(true, createIntSequence(1, 3, 10), NEGATE); assertFold(true, createDoubleSequence(1, 3, 10), NEGATE); - assertFold(false, createIntSequence(1, 3, 10), Round.ROUND, Floor.FLOOR, Ceiling.CEILING); - assertFold(false, createDoubleSequence(1, 3, 10), Round.ROUND, Floor.FLOOR, Ceiling.CEILING); + assertFold(false, createIntSequence(1, 3, 10), Floor.FLOOR, Ceiling.CEILING); + assertFold(false, createDoubleSequence(1, 3, 10), Floor.FLOOR, Ceiling.CEILING); } @Theory @@ -206,8 +207,9 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { Assert.assertEquals(expectedAttributes, foundAttributes); } - private static RType getArgumentType(RAbstractVector operand) { - return RType.maxPrecedence(RType.Integer, operand.getRType()); + private static RType getArgumentType(UnaryArithmeticFactory factory, RAbstractVector operand) { + UnaryArithmetic operation = factory.createOperation(); + return operation.calculateResultType(RType.maxPrecedence(operation.getMinPrecedence(), operand.getRType())); } private static boolean isPrimitive(Object result) { @@ -248,7 +250,6 @@ public class UnaryArithmeticNodeTest extends BinaryVectorTest { } private static NodeHandle<UnaryArithmeticNode> create(UnaryArithmeticFactory factory) { - return createHandle(UnaryArithmeticNodeGen.create(factory, null), - (node, args) -> node.execute(args[0])); + return createHandle(UnaryArithmeticNodeGen.create(factory), (node, args) -> node.execute(args[0])); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java index 56f24c6840969e6049dfd536a99694f3c015011e..173f3a04867225f35d2a279bcf02c33869c853a4 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java @@ -122,7 +122,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { throw error(RError.Message.ARGUMENT_EMPTY, 2); } else { CompilerDirectives.transferToInterpreterAndInvalidate(); - return UnaryArithmeticNodeGen.create(unary, RError.Message.INVALID_ARG_TYPE_UNARY); + return UnaryArithmeticNodeGen.create(unary); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticSpecial.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticSpecial.java index ee507eba25d5427948199f6d1d3654836b3e5549..df79a03c484b97c66b4900950a24a7d3e0abd0e1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticSpecial.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticSpecial.java @@ -23,18 +23,19 @@ package com.oracle.truffle.r.nodes.binary; import com.oracle.truffle.api.CompilerDirectives; - import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.binary.BinaryArithmeticSpecialNodeGen.IntegerBinaryArithmeticSpecialNodeGen; +import com.oracle.truffle.r.nodes.unary.UnaryArithmeticSpecialNodeGen; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RSpecialFactory; import com.oracle.truffle.r.runtime.nodes.RNode; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.BinaryArithmeticFactory; +import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; /** * Fast-path for scalar values: these cannot have any class attribute. Note: we intentionally use @@ -47,37 +48,58 @@ import com.oracle.truffle.r.runtime.ops.BinaryArithmeticFactory; public abstract class BinaryArithmeticSpecial extends RNode { private final boolean handleNA; + private final BinaryArithmeticFactory binaryFactory; + private final UnaryArithmeticFactory unaryFactory; + @Child private BinaryArithmetic operation; - public BinaryArithmeticSpecial(BinaryArithmeticFactory opFactory) { - this.operation = opFactory.createOperation(); - this.handleNA = !(opFactory == BinaryArithmetic.POW || opFactory == BinaryArithmetic.MOD); + public BinaryArithmeticSpecial(BinaryArithmeticFactory binaryFactory, UnaryArithmeticFactory unaryFactory) { + this.binaryFactory = binaryFactory; + this.unaryFactory = unaryFactory; + this.operation = binaryFactory.createOperation(); + this.handleNA = !(binaryFactory == BinaryArithmetic.POW || binaryFactory == BinaryArithmetic.MOD); } - public static RSpecialFactory createSpecialFactory(BinaryArithmeticFactory opFactory) { - boolean handleIntegers = !(opFactory == BinaryArithmetic.POW || opFactory == BinaryArithmetic.DIV); - if (handleIntegers) { - return (signature, arguments, inReplacement) -> signature.getNonNullCount() == 0 && arguments.length == 2 - ? IntegerBinaryArithmeticSpecialNodeGen.create(opFactory, arguments[0], arguments[1]) : null; - } else { - return (signature, arguments, inReplacement) -> signature.getNonNullCount() == 0 && arguments.length == 2 - ? BinaryArithmeticSpecialNodeGen.create(opFactory, arguments[0], arguments[1]) : null; - } + public static RSpecialFactory createSpecialFactory(BinaryArithmeticFactory binaryFactory, UnaryArithmeticFactory unaryFactory) { + return (signature, arguments, inReplacement) -> { + if (signature.getNonNullCount() == 0) { + if (arguments.length == 2) { + boolean handleIntegers = !(binaryFactory == BinaryArithmetic.POW || binaryFactory == BinaryArithmetic.DIV); + if (handleIntegers) { + return IntegerBinaryArithmeticSpecialNodeGen.create(binaryFactory, unaryFactory, arguments[0], arguments[1]); + } else { + return BinaryArithmeticSpecialNodeGen.create(binaryFactory, unaryFactory, arguments[0], arguments[1]); + } + } else if (arguments.length == 1 && unaryFactory != null) { + return UnaryArithmeticSpecialNodeGen.create(unaryFactory, arguments[0]); + } + } + return null; + }; } @Specialization - protected double doDoubles(double left, double right) { - if (RRuntime.isNA(left) || RRuntime.isNA(right)) { + protected double doDoubles(double left, double right, + @Cached("createBinaryProfile()") ConditionProfile leftNanProfile, + @Cached("createBinaryProfile()") ConditionProfile rightNaProfile) { + if (leftNanProfile.profile(Double.isNaN(left))) { + checkFullCallNeededOnNA(); + return left; + } else if (rightNaProfile.profile(RRuntime.isNA(right))) { checkFullCallNeededOnNA(); - return isNaN(left) ? Double.NaN : RRuntime.DOUBLE_NA; + return RRuntime.DOUBLE_NA; } return getOperation().op(left, right); } - @Fallback - @SuppressWarnings("unused") - protected double doFallback(Object left, Object right) { - throw RSpecialFactory.throwFullCallNeeded(); + protected BinaryArithmeticNode createFull() { + return BinaryArithmeticNodeGen.create(binaryFactory, unaryFactory); + } + + @Specialization + protected Object doFallback(VirtualFrame frame, Object left, Object right, + @Cached("createFull()") BinaryArithmeticNode binary) { + return binary.call(frame, left, right); } protected BinaryArithmetic getOperation() { @@ -100,11 +122,11 @@ public abstract class BinaryArithmeticSpecial extends RNode { */ abstract static class IntegerBinaryArithmeticSpecial extends BinaryArithmeticSpecial { - IntegerBinaryArithmeticSpecial(BinaryArithmeticFactory opFactory) { - super(opFactory); + IntegerBinaryArithmeticSpecial(BinaryArithmeticFactory binaryFactory, UnaryArithmeticFactory unaryFactory) { + super(binaryFactory, unaryFactory); } - @Specialization + @Specialization(insertBefore = "doFallback") public int doIntegers(int left, int right, @Cached("createBinaryProfile()") ConditionProfile naProfile) { if (naProfile.profile(RRuntime.isNA(left) || RRuntime.isNA(right))) { @@ -114,7 +136,7 @@ public abstract class BinaryArithmeticSpecial extends RNode { return getOperation().op(left, right); } - @Specialization + @Specialization(insertBefore = "doFallback") public double doIntDouble(int left, double right, @Cached("createBinaryProfile()") ConditionProfile naProfile) { if (naProfile.profile(RRuntime.isNA(left) || RRuntime.isNA(right))) { @@ -124,12 +146,16 @@ public abstract class BinaryArithmeticSpecial extends RNode { return getOperation().op(left, right); } - @Specialization + @Specialization(insertBefore = "doFallback") public double doDoubleInt(double left, int right, - @Cached("createBinaryProfile()") ConditionProfile naProfile) { - if (naProfile.profile(RRuntime.isNA(left) || RRuntime.isNA(right))) { + @Cached("createBinaryProfile()") ConditionProfile leftNanProfile, + @Cached("createBinaryProfile()") ConditionProfile rightNaProfile) { + if (leftNanProfile.profile(Double.isNaN(left))) { + checkFullCallNeededOnNA(); + return left; + } else if (rightNaProfile.profile(RRuntime.isNA(right))) { checkFullCallNeededOnNA(); - return isNaN(left) ? Double.NaN : RRuntime.DOUBLE_NA; + return RRuntime.DOUBLE_NA; } return getOperation().op(left, right); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java index e87a36945f15d09cdabd87d1c33cba3d6a790a3f..3dc749f4e1cce69787d85f842b1f47802d27ec53 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java @@ -25,11 +25,13 @@ package com.oracle.truffle.r.nodes.function; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.NodeChild; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.NodeCost; import com.oracle.truffle.api.nodes.NodeInfo; -import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.api.source.SourceSection; import com.oracle.truffle.r.nodes.access.variables.LocalReadVariableNode; @@ -39,13 +41,14 @@ import com.oracle.truffle.r.runtime.FastROptions; import com.oracle.truffle.r.runtime.RDeparse; import com.oracle.truffle.r.runtime.RDispatch; import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RVisibility; import com.oracle.truffle.r.runtime.Utils; import com.oracle.truffle.r.runtime.builtins.RBuiltinDescriptor; import com.oracle.truffle.r.runtime.builtins.RSpecialFactory; import com.oracle.truffle.r.runtime.context.RContext; +import com.oracle.truffle.r.runtime.data.RAttributable; import com.oracle.truffle.r.runtime.data.RFunction; -import com.oracle.truffle.r.runtime.data.RPromise; import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.nodes.RNode; import com.oracle.truffle.r.runtime.nodes.RSyntaxCall; @@ -57,14 +60,12 @@ import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; final class PeekLocalVariableNode extends RNode implements RSyntaxLookup { @Child private LocalReadVariableNode read; + @Child private SetVisibilityNode visibility; - private final ConditionProfile isPromiseProfile = ConditionProfile.createBinaryProfile(); private final ValueProfile valueProfile = ValueProfile.createClassProfile(); - @Child private SetVisibilityNode visibility; - PeekLocalVariableNode(String name) { - this.read = LocalReadVariableNode.create(Utils.intern(name), false); + this.read = LocalReadVariableNode.create(Utils.intern(name), true); } @Override @@ -73,13 +74,6 @@ final class PeekLocalVariableNode extends RNode implements RSyntaxLookup { if (value == null) { throw RSpecialFactory.throwFullCallNeeded(); } - if (isPromiseProfile.profile(value instanceof RPromise)) { - RPromise promise = (RPromise) value; - if (!promise.isEvaluated()) { - throw RSpecialFactory.throwFullCallNeeded(); - } - return valueProfile.profile(promise.getValue()); - } return valueProfile.profile(value); } @@ -117,6 +111,46 @@ final class PeekLocalVariableNode extends RNode implements RSyntaxLookup { } } +@NodeChild(value = "delegate", type = RNode.class) +abstract class ClassCheckNode extends RNode { + + public abstract RNode getDelegate(); + + @Override + protected RSyntaxNode getRSyntaxNode() { + return getDelegate().asRSyntaxNode(); + } + + @Specialization + protected static int doInt(int value) { + return value; + } + + @Specialization + protected static double doDouble(double value) { + return value; + } + + @Specialization + protected static byte doLogical(byte value) { + return value; + } + + @Specialization + protected static String doString(String value) { + return value; + } + + @Specialization + public Object doGeneric(Object value, + @Cached("create()") ClassHierarchyNode classHierarchy) { + if (classHierarchy.execute(value) != null) { + throw RSpecialFactory.throwFullCallNeeded(); + } + return value; + } +} + @NodeInfo(cost = NodeCost.NONE) public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode, RSyntaxCall { @@ -246,12 +280,10 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode if (arg instanceof RSyntaxLookup) { String lookup = ((RSyntaxLookup) arg).getIdentifier(); if (ArgumentsSignature.VARARG_NAME.equals(lookup)) { + // cannot map varargs return null; } if (i < evaluatedArgs) { - // not quite correct: - // || (dispatch == RDispatch.DEFAULT - // && builtinDescriptor.evaluatesArg(i)) localArguments[i] = arg.asRNode(); } else { localArguments[i] = new PeekLocalVariableNode(lookup); @@ -262,6 +294,16 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode assert arg instanceof RCallSpecialNode; localArguments[i] = arg.asRNode(); } + if (dispatch.isGroupGeneric() || dispatch == RDispatch.INTERNAL_GENERIC && i == 0) { + if (localArguments[i] instanceof RSyntaxConstant) { + Object value = ((RSyntaxConstant) localArguments[i]).getValue(); + if (value instanceof RAttributable && ((RAttributable) value).getAttr(RRuntime.CLASS_ATTR_KEY) != null) { + return null; + } + } else { + localArguments[i] = ClassCheckNodeGen.create(localArguments[i]); + } + } } } RNode special = specialCall.create(signature, localArguments, inReplace); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticBuiltinNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticBuiltinNode.java index b251f47a8409c50103beae9c99b8ebc40e918665..43a61ee7c7fcaca0e166ddbe1d2d365ec94d7c1b 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticBuiltinNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticBuiltinNode.java @@ -22,96 +22,25 @@ */ package com.oracle.truffle.r.nodes.unary; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; -import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNodeGen; +import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RType; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; -public abstract class UnaryArithmeticBuiltinNode extends RBuiltinNode.Arg1 implements UnaryArithmeticFactory { +public final class UnaryArithmeticBuiltinNode extends RBuiltinNode.Arg1 { - @Child private BoxPrimitiveNode boxPrimitive = BoxPrimitiveNodeGen.create(); - @Child private UnaryArithmeticNode unaryNode; - - protected UnaryArithmeticBuiltinNode(RType minPrecedence, RError.Message error, Object[] errorArgs) { - unaryNode = UnaryArithmeticNodeGen.create(this, minPrecedence, error, errorArgs); + { + Casts casts = new Casts(UnaryArithmeticBuiltinNode.class); + casts.arg(0).boxPrimitive(); } - protected UnaryArithmeticBuiltinNode(RType minPrecedence) { - unaryNode = UnaryArithmeticNodeGen.create(this, minPrecedence, RError.Message.ARGUMENTS_PASSED, new Object[]{0, "'" + getRBuiltin().name() + "'", 1}); - } + @Child private UnaryArithmeticNode unaryNode; - @Specialization - public Object calculateUnboxed(Object value) { - return unaryNode.execute(boxPrimitive.execute(value)); + public UnaryArithmeticBuiltinNode(UnaryArithmeticFactory factory) { + unaryNode = UnaryArithmeticNodeGen.create(factory); } @Override - public UnaryArithmetic createOperation() { - return new UnaryArithmetic() { - - @Override - public RType calculateResultType(RType argumentType) { - return UnaryArithmeticBuiltinNode.this.calculateResultType(argumentType); - } - - @Override - public double opd(double re, double im) { - return UnaryArithmeticBuiltinNode.this.opd(re, im); - } - - @Override - public int op(byte op) { - return UnaryArithmeticBuiltinNode.this.op(op); - } - - @Override - public int op(int op) { - return UnaryArithmeticBuiltinNode.this.op(op); - } - - @Override - public double op(double op) { - return UnaryArithmeticBuiltinNode.this.op(op); - } - - @Override - public RComplex op(double re, double im) { - return UnaryArithmeticBuiltinNode.this.op(re, im); - } - }; - } - - protected RType calculateResultType(RType argumentType) { - return argumentType; - } - - @SuppressWarnings("unused") - protected int op(byte op) { - throw new UnsupportedOperationException(); - } - - @SuppressWarnings("unused") - protected int op(int op) { - throw new UnsupportedOperationException(); - } - - @SuppressWarnings("unused") - protected double op(double op) { - throw new UnsupportedOperationException(); - } - - @SuppressWarnings("unused") - protected RComplex op(double re, double im) { - throw new UnsupportedOperationException(); - } - - @SuppressWarnings("unused") - protected double opd(double re, double im) { - throw new UnsupportedOperationException(); + public Object execute(VirtualFrame frame, Object value) { + return unaryNode.execute(value); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java index 2e4061d8fe6cf3a61dcbd7d29f35a37d41ecded6..1f0892d7285603e2bf2950d8ba38a5b7df910ed3 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java @@ -29,40 +29,24 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.primitive.UnaryMapNode; import com.oracle.truffle.r.nodes.profile.TruffleBoundaryNode; -import com.oracle.truffle.r.runtime.RError.Message; -import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; public abstract class UnaryArithmeticNode extends UnaryNode { protected final UnaryArithmeticFactory unary; - private final Message error; - private final Object[] errorArgs; - protected final RType minPrecedence; - public UnaryArithmeticNode(UnaryArithmeticFactory factory, RType minPrecedence, Message error, Object... errorArgs) { + public UnaryArithmeticNode(UnaryArithmeticFactory factory) { this.unary = factory; - this.error = error; - this.errorArgs = errorArgs; - this.minPrecedence = minPrecedence; - } - - public UnaryArithmeticNode(UnaryArithmeticFactory factory, Message error, RType minPrecedence) { - this.unary = factory; - this.error = error; - this.minPrecedence = minPrecedence; - this.errorArgs = null; - } - - public UnaryArithmeticNode(UnaryArithmeticFactory factory, Message error) { - this(factory, error, RType.Integer); } public abstract Object execute(Object value); @@ -75,17 +59,17 @@ public abstract class UnaryArithmeticNode extends UnaryNode { protected UnaryMapNode createCachedFast(Object operand) { if (isNumericVector(operand)) { - return createCached(unary.createOperation(), operand, minPrecedence); + return createCached(unary.createOperation(), operand); } return null; } - protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand, RType minPrecedence) { + protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand) { if (operand instanceof RAbstractVector) { RAbstractVector castOperand = (RAbstractVector) operand; RType operandType = castOperand.getRType(); if (operandType.isNumeric()) { - RType type = RType.maxPrecedence(operandType, minPrecedence); + RType type = RType.maxPrecedence(operandType, arithmetic.getMinPrecedence()); RType resultType = arithmetic.calculateResultType(type); return UnaryMapNode.create(new ScalarUnaryArithmeticNode(arithmetic), castOperand, type, resultType); } @@ -101,24 +85,24 @@ public abstract class UnaryArithmeticNode extends UnaryNode { @TruffleBoundary protected Object doGeneric(Object operand, @Cached("unary.createOperation()") UnaryArithmetic arithmetic, - @Cached("new(createCached(arithmetic, operand, minPrecedence), minPrecedence)") GenericNumericVectorNode generic) { + @Cached("new(createCached(arithmetic, operand))") GenericNumericVectorNode generic) { RAbstractVector operandVector = (RAbstractVector) operand; return generic.get(arithmetic, operandVector).apply(operandVector); } + @Override + public RBaseNode getErrorContext() { + return this; + } + @Fallback - protected Object invalidArgType(@SuppressWarnings("unused") Object operand) { + protected Object invalidArgType(Object operand) { CompilerDirectives.transferToInterpreter(); - if (errorArgs == null || errorArgs.length == 0) { - throw error(error); - } else if (errorArgs.length == 1) { - throw error(error, errorArgs[0]); - } else if (errorArgs.length == 2) { - throw error(error, errorArgs[0], errorArgs[1]); - } else if (errorArgs.length == 3) { - throw error(error, errorArgs[0], errorArgs[1], errorArgs[2]); + UnaryArithmetic op = unary.createOperation(); + if (operand instanceof RMissing) { + throw error(RError.Message.ARGUMENTS_PASSED, 0, "'" + op.getClass().getSimpleName().toLowerCase() + "'", 1); } else { - throw RInternalError.shouldNotReachHere("too many error arguments in UnaryArithmeticNode"); + throw error(op.getArgumentError()); } } @@ -126,17 +110,14 @@ public abstract class UnaryArithmeticNode extends UnaryNode { @Child private UnaryMapNode cached; - private final RType minPrecedence; - - public GenericNumericVectorNode(UnaryMapNode cachedOperation, RType minPrecedence) { + public GenericNumericVectorNode(UnaryMapNode cachedOperation) { this.cached = cachedOperation; - this.minPrecedence = minPrecedence; } public UnaryMapNode get(UnaryArithmetic arithmetic, RAbstractVector operand) { UnaryMapNode next = cached; if (!next.isSupported(operand)) { - next = cached.replace(createCached(arithmetic, operand, minPrecedence)); + next = cached.replace(createCached(arithmetic, operand)); } return next; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticSpecial.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticSpecial.java new file mode 100644 index 0000000000000000000000000000000000000000..1b42ac2e6224cb603a925c75a80c4e48506108f7 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticSpecial.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.unary; + +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.ImportStatic; +import com.oracle.truffle.api.dsl.NodeChild; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.builtins.RSpecialFactory; +import com.oracle.truffle.r.runtime.nodes.RNode; +import com.oracle.truffle.r.runtime.ops.UnaryArithmetic; +import com.oracle.truffle.r.runtime.ops.UnaryArithmeticFactory; + +/** + * Fast-path for scalar values: these cannot have any class attribute. Note: we intentionally use + * empty type system to avoid conversions to vector types. + */ +@ImportStatic(RType.class) +@NodeChild(value = "operand", type = RNode.class) +public abstract class UnaryArithmeticSpecial extends RNode { + + private final UnaryArithmeticFactory unaryFactory; + + @Child protected UnaryArithmetic operation; + + protected UnaryArithmeticSpecial(UnaryArithmeticFactory unaryFactory) { + this.unaryFactory = unaryFactory; + this.operation = unaryFactory.createOperation(); + } + + public static RSpecialFactory createSpecialFactory(UnaryArithmeticFactory unaryFactory) { + return (signature, arguments, inReplacement) -> signature.getNonNullCount() == 0 && arguments.length == 1 + ? UnaryArithmeticSpecialNodeGen.create(unaryFactory, arguments[0]) : null; + } + + @Specialization + protected double doDoubles(double operand, + @Cached("createBinaryProfile()") ConditionProfile naProfile) { + if (naProfile.profile(RRuntime.isNA(operand))) { + return operand; + } + return getOperation().op(operand); + } + + protected UnaryArithmeticNode createFull() { + return UnaryArithmeticNodeGen.create(unaryFactory); + } + + @Specialization(guards = "operation.getMinPrecedence() == Integer") + public int doIntegers(int operand, + @Cached("createBinaryProfile()") ConditionProfile naProfile) { + if (naProfile.profile(RRuntime.isNA(operand))) { + return RRuntime.INT_NA; + } + return getOperation().op(operand); + } + + @Specialization(guards = "operation.getMinPrecedence() == Double") + public double doIntegersDouble(int operand, + @Cached("createBinaryProfile()") ConditionProfile naProfile) { + if (naProfile.profile(RRuntime.isNA(operand))) { + return RRuntime.INT_NA; + } + return getOperation().op((double) operand); + } + + @Specialization + protected Object doFallback(Object operand, + @Cached("create()") BoxPrimitiveNode boxPrimitive, + @Cached("createFull()") UnaryArithmeticNode unary) { + return unary.execute(boxPrimitive.execute(operand)); + } + + protected UnaryArithmetic getOperation() { + return operation; + } +} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index d5a95388201ad5c3bc91b398e87617eee8c36532..16a847179a575ac1d909fd34eeebaacf9ef677ab 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -305,7 +305,7 @@ public final class RError extends RuntimeException { MORE_SUPPLIED_REPLACE("more elements supplied than there are to replace"), NA_SUBSCRIPTED("NAs are not allowed in subscripted assignments"), INVALID_ARG_TYPE("invalid argument type"), - INVALID_ARG_TYPE_UNARY("invalid argument to unary operator"), + INVALID_ARG_UNARY("invalid argument to unary operator"), VECTOR_SIZE_NEGATIVE("vector size cannot be negative"), VECTOR_SIZE_NA("vector size cannot be NA"), VECTOR_SIZE_NA_NAN("vector size cannot be NA/NaN"), diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/Operation.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/Operation.java index 96ad234f0340a3f0d7b856877dc6c551c1a829a6..475484aa6cb034290bb28759ac3de978ed7cd2c6 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/Operation.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/Operation.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -24,6 +24,7 @@ package com.oracle.truffle.r.runtime.ops; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.ReturnException; import com.oracle.truffle.r.runtime.nodes.RBaseNode; public class Operation extends RBaseNode { @@ -45,6 +46,12 @@ public class Operation extends RBaseNode { } public static RuntimeException handleException(Throwable e) { - throw e instanceof RError ? (RError) e : RInternalError.shouldNotReachHere(e, "only RErrors should be thrown by arithmetic ops"); + if (e instanceof RError) { + throw (RError) e; + } else if (e instanceof ReturnException) { + throw (ReturnException) e; + } else { + throw RInternalError.shouldNotReachHere(e, "only RErrors or ReturnExceptions should be thrown by arithmetic ops"); + } } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java index 924015da662ec4140f5a4e1379c837ac914571ee..f32b5bd1895888e51be17e5273436b3f63572c36 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/UnaryArithmetic.java @@ -7,16 +7,22 @@ * Copyright (c) 1998, Ross Ihaka * Copyright (c) 1998-2012, The R Core Team * Copyright (c) 2005, The R Foundation - * Copyright (c) 2013, 2016, Oracle and/or its affiliates + * Copyright (c) 2013, 2017, Oracle and/or its affiliates * * All rights reserved. */ package com.oracle.truffle.r.runtime.ops; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; +/** + * Base class for the implementation of unary arithmetic operations. This covers functions like "+", + * "-", "sqrt", and many more. + */ public abstract class UnaryArithmetic extends Operation { public static final UnaryArithmeticFactory NEGATE = Negate::new; @@ -26,24 +32,62 @@ public abstract class UnaryArithmetic extends Operation { super(false, false); } + /** + * The lowest type with which this operation will be executed. E.g., if this is double, then + * integer arguments will be coerced to double before performing the actual operation. + */ + public RType getMinPrecedence() { + return RType.Double; + } + + /** + * Specifies the error that will be raised for invalid argument types. + */ + public Message getArgumentError() { + return RError.Message.NON_NUMERIC_MATH; + } + + /** + * Determines, for a given argument type (after coercion according to + * {@link #getMinPrecedence()}), the type of the return value. This is mainly intended to + * support operations that return double values for complex arguments. + */ public RType calculateResultType(RType argumentType) { return argumentType; } - public abstract int op(byte op); + public int op(@SuppressWarnings("unused") byte op) { + throw new UnsupportedOperationException(); + } - public abstract int op(int op); + public int op(@SuppressWarnings("unused") int op) { + throw new UnsupportedOperationException(); + } - public abstract double op(double op); + public double op(@SuppressWarnings("unused") double op) { + throw new UnsupportedOperationException(); + } - public abstract RComplex op(double re, double im); + public RComplex op(double re, double im) { + // default: perform operation on real and imaginary part + return RDataFactory.createComplex(op(re), op(im)); + } - @SuppressWarnings("unused") - public double opd(double re, double im) { + public double opd(@SuppressWarnings("unused") double re, @SuppressWarnings("unused") double im) { throw new UnsupportedOperationException(); } - public static class Negate extends UnaryArithmetic { + public static final class Negate extends UnaryArithmetic { + + @Override + public RType getMinPrecedence() { + return RType.Integer; + } + + @Override + public Message getArgumentError() { + return RError.Message.INVALID_ARG_UNARY; + } @Override public int op(int op) { @@ -59,14 +103,19 @@ public abstract class UnaryArithmetic extends Operation { public int op(byte op) { return -(int) op; } + } + + public static final class Plus extends UnaryArithmetic { @Override - public RComplex op(double re, double im) { - return RDataFactory.createComplex(op(re), op(im)); + public RType getMinPrecedence() { + return RType.Integer; } - } - public static class Plus extends UnaryArithmetic { + @Override + public Message getArgumentError() { + return RError.Message.INVALID_ARG_UNARY; + } @Override public int op(int op) { diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 1dabb9fb2ab419912e252d2a642c8dfc5248506e..35a047a9fdd066e8d0e3ea20bfdb7d9cbaa666e0 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -13355,16 +13355,16 @@ structure(integer(0), .Dim = c(0L, 3L), .Dimnames = list(NULL, [2,] 2 4 12 14 ##com.oracle.truffle.r.test.builtins.TestBuiltin_ceiling.testCeiling# -#{ ceiling(c(0.2,-3.4,NA,0/0,1/0)) } -[1] 1 -3 NA NaN Inf +#if (length(grep("FastR", R.Version()$version.string)) != 1) { 2+2i } else { { ceiling(1.1+1.9i); } } +[1] 2+2i ##com.oracle.truffle.r.test.builtins.TestBuiltin_ceiling.testCeiling# -#{ trunc("aaa"); } -Error in trunc("aaa") : non-numeric argument to mathematical function +#{ ceiling("aaa"); } +Error in ceiling("aaa") : non-numeric argument to mathematical function ##com.oracle.truffle.r.test.builtins.TestBuiltin_ceiling.testCeiling# -#{ trunc(1+1i); } -Error in trunc(1 + (0+1i)) : unimplemented complex function +#{ ceiling(c(0.2,-3.4,NA,0/0,1/0)) } +[1] 1 -3 NA NaN Inf ##com.oracle.truffle.r.test.builtins.TestBuiltin_ceiling.testCeiling# #{ typeof(ceiling(42L)); } @@ -24166,16 +24166,16 @@ Error: 4 arguments passed to .Internal(findInterval) which requires 5 Error: 4 arguments passed to .Internal(findInterval) which requires 5 ##com.oracle.truffle.r.test.builtins.TestBuiltin_floor.testFloor# -#{ floor(c(0.2,-3.4,NA,0/0,1/0)) } -[1] 0 -4 NA NaN Inf +#if (length(grep("FastR", R.Version()$version.string)) != 1) { 1+1i } else { { floor(1.1+1.9i); } } +[1] 1+1i ##com.oracle.truffle.r.test.builtins.TestBuiltin_floor.testFloor# -#{ trunc("aaa"); } -Error in trunc("aaa") : non-numeric argument to mathematical function +#{ floor("aaa"); } +Error in floor("aaa") : non-numeric argument to mathematical function ##com.oracle.truffle.r.test.builtins.TestBuiltin_floor.testFloor# -#{ trunc(1+1i); } -Error in trunc(1 + (0+1i)) : unimplemented complex function +#{ floor(c(0.2,-3.4,NA,0/0,1/0)) } +[1] 0 -4 NA NaN Inf ##com.oracle.truffle.r.test.builtins.TestBuiltin_floor.testFloor# #{ typeof(floor(42L)); } @@ -70222,12 +70222,12 @@ tracemem[0x7f818486bc90 -> 0x7f818486ba50]: [71] 0.21654883 0.01307171 ##com.oracle.truffle.r.test.builtins.TestBuiltin_trunc.testTrunc# -#{ trunc("aaa"); } -Error in trunc("aaa") : non-numeric argument to mathematical function +#if (length(grep("FastR", R.Version()$version.string)) != 1) { 1+1i } else { { trunc(1.1+1.9i); } } +[1] 1+1i ##com.oracle.truffle.r.test.builtins.TestBuiltin_trunc.testTrunc# -#{ trunc(1+1i); } -Error in trunc(1 + (0+1i)) : unimplemented complex function +#{ trunc("aaa"); } +Error in trunc("aaa") : non-numeric argument to mathematical function ##com.oracle.truffle.r.test.builtins.TestBuiltin_trunc.testTrunc# #{ typeof(trunc(42L)); } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_ceiling.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_ceiling.java index a9b34b32a401376b65596567f2a082dddabb4930..873db5bf320d576af2520e8837c3903669fc97f4 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_ceiling.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_ceiling.java @@ -62,7 +62,8 @@ public class TestBuiltin_ceiling extends TestBase { assertEval("{ ceiling(c(0.2,-3.4,NA,0/0,1/0)) }"); assertEval("{ typeof(ceiling(42L)); }"); assertEval("{ typeof(ceiling(TRUE)); }"); - assertEval("{ trunc(1+1i); }"); - assertEval("{ trunc(\"aaa\"); }"); + // not implemented for complex in GNU R + assertEvalFastR("{ ceiling(1.1+1.9i); }", "2+2i"); + assertEval("{ ceiling(\"aaa\"); }"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_floor.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_floor.java index 39355fc4872fbb9ec404f8c1c343e6f0902ccf57..3ecc9803b79a8e09070eefaa7cfd2e6c7e9bbb52 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_floor.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_floor.java @@ -67,7 +67,8 @@ public class TestBuiltin_floor extends TestBase { assertEval("{ floor(c(0.2,-3.4,NA,0/0,1/0)) }"); assertEval("{ typeof(floor(42L)); }"); assertEval("{ typeof(floor(TRUE)); }"); - assertEval("{ trunc(1+1i); }"); - assertEval("{ trunc(\"aaa\"); }"); + // not implemented for complex in GNU R + assertEvalFastR("{ floor(1.1+1.9i); }", "1+1i"); + assertEval("{ floor(\"aaa\"); }"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_trunc.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_trunc.java index f2080c629791428866d06c5d55d58a8be498e3ed..c979ed66c228c9c40eb90a3c9a4306cc12b952c2 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_trunc.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_trunc.java @@ -41,7 +41,8 @@ public class TestBuiltin_trunc extends TestBase { public void testTrunc() { assertEval("{ typeof(trunc(42L)); }"); assertEval("{ typeof(trunc(TRUE)); }"); - assertEval("{ trunc(1+1i); }"); + // not implemented for complex in GNU R + assertEvalFastR("{ trunc(1.1+1.9i); }", "1+1i"); assertEval("{ trunc(\"aaa\"); }"); } }