diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Repeat.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Repeat.java index 8b92d9861f81dc974091a8026f371a81cb12b5ce..1d3f608b04ee868b165df3dcce6e880683216bef 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Repeat.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Repeat.java @@ -35,21 +35,33 @@ import java.util.Arrays; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.InitAttributesNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.RepeatNodeGen.FastRInternalRepeatNodeGen; +import com.oracle.truffle.r.nodes.builtin.casts.fluent.PipelineBuilder; +import com.oracle.truffle.r.nodes.function.FormalArguments; +import com.oracle.truffle.r.nodes.function.call.PrepareMatchInternalArguments; +import com.oracle.truffle.r.nodes.function.call.PrepareMatchInternalArgumentsNodeGen; +import com.oracle.truffle.r.nodes.unary.CastNode; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RStringVector; +import com.oracle.truffle.r.runtime.data.RTypes; import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; /** * The {@code rep} builtin works as follows. @@ -71,166 +83,213 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractVector; * </ul> * </ol> */ -@RBuiltin(name = "rep", kind = PRIMITIVE, parameterNames = {"x", "times", "length.out", "each"}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@RBuiltin(name = "rep", kind = PRIMITIVE, parameterNames = {"x", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE) public abstract class Repeat extends RBuiltinNode { - private final ConditionProfile lengthOutOrTimes = ConditionProfile.createBinaryProfile(); - private final ConditionProfile oneTimeGiven = ConditionProfile.createBinaryProfile(); - private final ConditionProfile replicateOnce = ConditionProfile.createBinaryProfile(); - @Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create(); + private static final PipelineBuilder PB_TIMES; + private static final PipelineBuilder PB_LENGTH_OUT; + private static final PipelineBuilder PB_EACH; + private static final FormalArguments FORMALS; + private static final int ARG_IDX_TIMES; + private static final int ARG_IDX_LENGHT_OUT; + private static final int ARG_IDX_EACH; + + @Child private FastRInternalRepeat internalNode = FastRInternalRepeatNodeGen.create(); + @Child private CastNode castTimes = PB_TIMES.buildNode(); + @Child private CastNode castLengthOut = PB_LENGTH_OUT.buildNode(); + @Child private CastNode castEach = PB_EACH.buildNode(); + @Child private PrepareMatchInternalArguments prepareArgs = PrepareMatchInternalArgumentsNodeGen.create(FORMALS, this); @Override public Object[] getDefaultParameterValues() { - return new Object[]{RMissing.instance, 1, RRuntime.INT_NA, 1}; + return new Object[]{RMissing.instance, RArgsValuesAndNames.EMPTY}; } static { Casts casts = new Casts(Repeat.class); casts.arg("x").mustBe(abstractVectorValue(), RError.Message.ATTEMPT_TO_REPLICATE, typeName()); - casts.arg("times").defaultError(RError.Message.INVALID_ARGUMENT, "times").mustNotBeNull().asIntegerVector(); - casts.arg("length.out").asIntegerVector().shouldBe(size(1).or(size(0)), RError.Message.FIRST_ELEMENT_USED, "length.out").findFirst(RRuntime.INT_NA, - RError.Message.FIRST_ELEMENT_USED, "length.out").mustBe(intNA().or(gte(0))); - casts.arg("each").asIntegerVector().shouldBe(size(1).or(size(0)), RError.Message.FIRST_ELEMENT_USED, "each").findFirst(1, RError.Message.FIRST_ELEMENT_USED, "each").replaceNA( - 1).mustBe(gte(0)); - } - protected boolean hasNames(RAbstractVector x) { - return getNames.getNames(x) != null; - } + // prepare cast pipeline nodes for vararg matching + PB_TIMES = new PipelineBuilder("times"); + PB_TIMES.fluent().defaultError(RError.Message.INVALID_ARGUMENT, "times").mustNotBeNull().asIntegerVector(); - @Specialization(guards = {"x.getLength() == 1", "times.getLength() == 1", "each <= 1", "!hasNames(x)"}) - protected RAbstractVector repNoEachNoNamesSimple(RAbstractDoubleVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each) { - int t = times.getDataAt(0); - if (t < 0) { - throw error(RError.Message.INVALID_ARGUMENT, "times"); - } - int length = lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut)) ? lengthOut : t; - double[] data = new double[length]; - Arrays.fill(data, x.getDataAt(0)); - return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x.getDataAt(0))); - } + PB_LENGTH_OUT = new PipelineBuilder("length.out"); + PB_LENGTH_OUT.fluent().asIntegerVector().shouldBe(size(1).or(size(0)), + RError.Message.FIRST_ELEMENT_USED, "length.out").findFirst(RRuntime.INT_NA, + RError.Message.FIRST_ELEMENT_USED, "length.out").mustBe(intNA().or(gte(0))); - @Specialization(guards = {"each > 1", "!hasNames(x)"}) - protected RAbstractVector repEachNoNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, int each) { - if (times.getLength() > 1) { - throw error(RError.Message.INVALID_ARGUMENT, "times"); - } - RAbstractVector input = handleEach(x, each); - if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { - return handleLengthOut(input, lengthOut, false); - } else { - return handleTimes(input, times, false); - } + PB_EACH = new PipelineBuilder("each"); + PB_EACH.fluent().asIntegerVector().shouldBe(size(1).or(size(0)), + RError.Message.FIRST_ELEMENT_USED, "each").findFirst(1, RError.Message.FIRST_ELEMENT_USED, + "each").replaceNA(1).mustBe(gte(0)); + + ArgumentsSignature signature = ArgumentsSignature.get("times", "length.out", "each"); + ARG_IDX_TIMES = 0; + ARG_IDX_LENGHT_OUT = 1; + ARG_IDX_EACH = 2; + FORMALS = FormalArguments.createForBuiltin(new Object[]{1, RRuntime.INT_NA, 1}, signature); } - @Specialization(guards = {"each <= 1", "!hasNames(x)"}) - protected RAbstractVector repNoEachNoNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each) { - if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { - return handleLengthOut(x, lengthOut, true); - } else { - return handleTimes(x, times, true); - } + @Specialization + protected Object repeat(VirtualFrame frame, RAbstractVector x, RArgsValuesAndNames args) { + RArgsValuesAndNames margs = prepareArgs.execute(args, null); + + // cast arguments + Object times = castTimes.execute(margs.getArgument(ARG_IDX_TIMES)); + Object lengthOut = castLengthOut.execute(margs.getArgument(ARG_IDX_LENGHT_OUT)); + Object each = castEach.execute(margs.getArgument(ARG_IDX_EACH)); + + return internalNode.execute(frame, x, times, lengthOut, each); } - @Specialization(guards = {"each > 1", "hasNames(x)"}) - protected RAbstractVector repEachNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, int each, - @Cached("create()") InitAttributesNode initAttributes, - @Cached("createNames()") SetFixedAttributeNode putNames) { - if (times.getLength() > 1) { - throw error(RError.Message.INVALID_ARGUMENT, "times"); - } - RAbstractVector input = handleEach(x, each); - RStringVector names = (RStringVector) handleEach(getNames.getNames(x), each); - RVector<?> r; - if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { - names = (RStringVector) handleLengthOut(names, lengthOut, false); - r = handleLengthOut(input, lengthOut, false); - } else { - names = (RStringVector) handleTimes(names, times, false); - r = handleTimes(input, times, false); + @TypeSystemReference(RTypes.class) + abstract static class FastRInternalRepeat extends RBaseNode { + private final ConditionProfile lengthOutOrTimes = ConditionProfile.createBinaryProfile(); + private final ConditionProfile oneTimeGiven = ConditionProfile.createBinaryProfile(); + private final ConditionProfile replicateOnce = ConditionProfile.createBinaryProfile(); + + @Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create(); + + public abstract RAbstractVector execute(VirtualFrame frame, Object... args); + + protected boolean hasNames(RAbstractVector x) { + return getNames.getNames(x) != null; } - putNames.execute(initAttributes.execute(r), names); - return r; - } - @Specialization(guards = {"each <= 1", "hasNames(x)"}) - protected RAbstractVector repNoEachNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each, - @Cached("create()") InitAttributesNode initAttributes, - @Cached("createNames()") SetFixedAttributeNode putNames) { - RStringVector names; - RVector<?> r; - if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { - names = (RStringVector) handleLengthOut(getNames.getNames(x), lengthOut, true); - r = handleLengthOut(x, lengthOut, true); - } else { - names = (RStringVector) handleTimes(getNames.getNames(x), times, true); - r = handleTimes(x, times, true); + @Specialization(guards = {"x.getLength() == 1", "times.getLength() == 1", "each <= 1", "!hasNames(x)"}) + protected RAbstractVector repNoEachNoNamesSimple(RAbstractDoubleVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each) { + int t = times.getDataAt(0); + if (t < 0) { + throw error(RError.Message.INVALID_ARGUMENT, "times"); + } + int length = lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut)) ? lengthOut : t; + double[] data = new double[length]; + Arrays.fill(data, x.getDataAt(0)); + return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x.getDataAt(0))); } - putNames.execute(initAttributes.execute(r), names); - return r; - } - /** - * Prepare the input vector by replicating its elements. - */ - private static RVector<?> handleEach(RAbstractVector x, int each) { - RVector<?> r = x.createEmptySameType(x.getLength() * each, x.isComplete()); - for (int i = 0; i < x.getLength(); i++) { - for (int j = i * each; j < (i + 1) * each; j++) { - r.transferElementSameType(j, x, i); + @Specialization(guards = {"each > 1", "!hasNames(x)"}) + protected RAbstractVector repEachNoNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, int each) { + if (times.getLength() > 1) { + throw error(RError.Message.INVALID_ARGUMENT, "times"); + } + RAbstractVector input = handleEach(x, each); + if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { + return handleLengthOut(input, lengthOut, false); + } else { + return handleTimes(input, times, false); } } - return r; - } - /** - * Extend or truncate the vector to a specified length. - */ - private static RVector<?> handleLengthOut(RAbstractVector x, int lengthOut, boolean copyIfSameSize) { - if (x.getLength() == lengthOut) { - return (RVector<?>) (copyIfSameSize ? x.copy() : x); + @Specialization(guards = {"each <= 1", "!hasNames(x)"}) + protected RAbstractVector repNoEachNoNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each) { + if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { + return handleLengthOut(x, lengthOut, true); + } else { + return handleTimes(x, times, true); + } } - return x.copyResized(lengthOut, false); - } - /** - * Replicate the vector a given number of times. - */ - private RVector<?> handleTimes(RAbstractVector x, RAbstractIntVector times, boolean copyIfSameSize) { - if (oneTimeGiven.profile(times.getLength() == 1)) { - // only one times value is given - final int howManyTimes = times.getDataAt(0); - if (howManyTimes < 0) { + @Specialization(guards = {"each > 1", "hasNames(x)"}) + protected RAbstractVector repEachNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, int each, + @Cached("create()") InitAttributesNode initAttributes, + @Cached("createNames()") SetFixedAttributeNode putNames) { + if (times.getLength() > 1) { throw error(RError.Message.INVALID_ARGUMENT, "times"); } - if (replicateOnce.profile(howManyTimes == 1)) { - return (RVector<?>) (copyIfSameSize ? x.copy() : x); + RAbstractVector input = handleEach(x, each); + RStringVector names = (RStringVector) handleEach(getNames.getNames(x), each); + RVector<?> r; + if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { + names = (RStringVector) handleLengthOut(names, lengthOut, false); + r = handleLengthOut(input, lengthOut, false); } else { - return x.copyResized(x.getLength() * howManyTimes, false); + names = (RStringVector) handleTimes(names, times, false); + r = handleTimes(input, times, false); } - } else { - // times is a vector with several elements - if (x.getLength() != times.getLength()) { - throw error(RError.Message.INVALID_ARGUMENT, "times"); - } - // iterate once over the times vector to determine result vector size - int resultLength = 0; - for (int i = 0; i < times.getLength(); i++) { - int t = times.getDataAt(i); - if (t < 0) { - throw error(RError.Message.INVALID_ARGUMENT, "times"); - } - resultLength += t; + putNames.execute(initAttributes.execute(r), names); + return r; + } + + @Specialization(guards = {"each <= 1", "hasNames(x)"}) + protected RAbstractVector repNoEachNames(RAbstractVector x, RAbstractIntVector times, int lengthOut, @SuppressWarnings("unused") int each, + @Cached("create()") InitAttributesNode initAttributes, + @Cached("createNames()") SetFixedAttributeNode putNames) { + RStringVector names; + RVector<?> r; + if (lengthOutOrTimes.profile(!RRuntime.isNA(lengthOut))) { + names = (RStringVector) handleLengthOut(getNames.getNames(x), lengthOut, true); + r = handleLengthOut(x, lengthOut, true); + } else { + names = (RStringVector) handleTimes(getNames.getNames(x), times, true); + r = handleTimes(x, times, true); } - // create and populate result vector - RVector<?> r = x.createEmptySameType(resultLength, x.isComplete()); - int wp = 0; // write pointer + putNames.execute(initAttributes.execute(r), names); + return r; + } + + /** + * Prepare the input vector by replicating its elements. + */ + private static RVector<?> handleEach(RAbstractVector x, int each) { + RVector<?> r = x.createEmptySameType(x.getLength() * each, x.isComplete()); for (int i = 0; i < x.getLength(); i++) { - for (int j = 0; j < times.getDataAt(i); ++j, ++wp) { - r.transferElementSameType(wp, x, i); + for (int j = i * each; j < (i + 1) * each; j++) { + r.transferElementSameType(j, x, i); } } return r; } + + /** + * Extend or truncate the vector to a specified length. + */ + private static RVector<?> handleLengthOut(RAbstractVector x, int lengthOut, boolean copyIfSameSize) { + if (x.getLength() == lengthOut) { + return (RVector<?>) (copyIfSameSize ? x.copy() : x); + } + return x.copyResized(lengthOut, false); + } + + /** + * Replicate the vector a given number of times. + */ + private RVector<?> handleTimes(RAbstractVector x, RAbstractIntVector times, boolean copyIfSameSize) { + if (oneTimeGiven.profile(times.getLength() == 1)) { + // only one times value is given + final int howManyTimes = times.getDataAt(0); + if (howManyTimes < 0) { + throw error(RError.Message.INVALID_ARGUMENT, "times"); + } + if (replicateOnce.profile(howManyTimes == 1)) { + return (RVector<?>) (copyIfSameSize ? x.copy() : x); + } else { + return x.copyResized(x.getLength() * howManyTimes, false); + } + } else { + // times is a vector with several elements + if (x.getLength() != times.getLength()) { + throw error(RError.Message.INVALID_ARGUMENT, "times"); + } + // iterate once over the times vector to determine result vector size + int resultLength = 0; + for (int i = 0; i < times.getLength(); i++) { + int t = times.getDataAt(i); + if (t < 0) { + throw error(RError.Message.INVALID_ARGUMENT, "times"); + } + resultLength += t; + } + // create and populate result vector + RVector<?> r = x.createEmptySameType(resultLength, x.isComplete()); + int wp = 0; // write pointer + for (int i = 0; i < x.getLength(); i++) { + for (int j = 0; j < times.getDataAt(i); ++j, ++wp) { + r.transferElementSameType(wp, x, i); + } + } + return r; + } + } } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java index 1cf534c2ebf3c111288f23f33305c0ec6d99042c..850535f5380023c8cd93d8b1800f8e196a7fda9c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java @@ -211,10 +211,10 @@ public class ArgumentMatcher { } /** - * Used for the implementation of the 'UseMethod' builtin. Reorders the arguments passed into - * the called, generic function and prepares them to be passed into the specific function + * Used for matching varargs to an internally defined signature. Reorders the arguments passed + * into the called, generic function and prepares them to be passed into the specific function * - * @param target The 'Method' which is going to be 'Use'd + * @param formals The formal arguments to match to. * @param evaluatedArgs The arguments which are already in evaluated form (as they are directly * taken from the stack) * @param s3DefaultArguments default values carried over from S3 group dispatch method (e.g. @@ -223,8 +223,7 @@ public class ArgumentMatcher { * @return A Fresh {@link RArgsValuesAndNames} containing the arguments rearranged and stuffed * with default values (in the form of {@link RPromise}s where needed) */ - public static RArgsValuesAndNames matchArgumentsEvaluated(RRootNode target, RArgsValuesAndNames evaluatedArgs, S3DefaultArguments s3DefaultArguments, RBaseNode callingNode) { - FormalArguments formals = target.getFormalArguments(); + public static RArgsValuesAndNames matchArgumentsEvaluated(FormalArguments formals, RArgsValuesAndNames evaluatedArgs, S3DefaultArguments s3DefaultArguments, RBaseNode callingNode) { MatchPermutation match = permuteArguments(evaluatedArgs.getSignature(), formals.getSignature(), callingNode, index -> { throw RInternalError.unimplemented("S3Dispatch should not have arg length mismatch"); }, index -> evaluatedArgs.getSignature().getName(index), null); @@ -261,6 +260,23 @@ public class ArgumentMatcher { return new RArgsValuesAndNames(evaledArgs, formals.getSignature()); } + /** + * Used for the implementation of the 'UseMethod' builtin. Reorders the arguments passed into + * the called, generic function and prepares them to be passed into the specific function + * + * @param target The 'Method' which is going to be 'Use'd + * @param evaluatedArgs The arguments which are already in evaluated form (as they are directly + * taken from the stack) + * @param s3DefaultArguments default values carried over from S3 group dispatch method (e.g. + * from max to Summary.factor). {@code null} if there are no such arguments. + * @param callingNode The {@link Node} invoking the match + * @return A Fresh {@link RArgsValuesAndNames} containing the arguments rearranged and stuffed + * with default values (in the form of {@link RPromise}s where needed) + */ + public static RArgsValuesAndNames matchArgumentsEvaluated(RRootNode target, RArgsValuesAndNames evaluatedArgs, S3DefaultArguments s3DefaultArguments, RBaseNode callingNode) { + return matchArgumentsEvaluated(target.getFormalArguments(), evaluatedArgs, s3DefaultArguments, callingNode); + } + private static String getErrorForArgument(RNode[] suppliedArgs, ArgumentsSignature suppliedSignature, int index) { RNode node = suppliedArgs[index]; if (node instanceof VarArgNode) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/call/PrepareMatchInternalArguments.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/call/PrepareMatchInternalArguments.java new file mode 100644 index 0000000000000000000000000000000000000000..23b150ba129ea609c0df36c851469658e5b813ef --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/call/PrepareMatchInternalArguments.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.function.call; + +import java.util.Objects; + +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.r.nodes.function.ArgumentMatcher; +import com.oracle.truffle.r.nodes.function.ArgumentMatcher.MatchPermutation; +import com.oracle.truffle.r.nodes.function.FormalArguments; +import com.oracle.truffle.r.nodes.function.RCallNode; +import com.oracle.truffle.r.runtime.ArgumentsSignature; +import com.oracle.truffle.r.runtime.RArguments.S3DefaultArguments; +import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; + +/** + * Basically the same as {@link PrepareArguments} but for a specific set of internal generic + * functions using a vararg parameter but expecting a specific amount of parameters by internally + * matching them. + */ +public abstract class PrepareMatchInternalArguments extends Node { + + protected static final int CACHE_SIZE = 8; + + protected final RBaseNode callingNode; + protected final FormalArguments formals; + + protected PrepareMatchInternalArguments(FormalArguments formals, RBaseNode callingNode) { + this.callingNode = Objects.requireNonNull(callingNode); + this.formals = Objects.requireNonNull(formals); + } + + protected MatchPermutation createArguments(ArgumentsSignature supplied) { + return ArgumentMatcher.matchArguments(supplied, formals.getSignature(), callingNode, null); + } + + @Specialization(limit = "CACHE_SIZE", guards = {"cachedExplicitArgSignature == explicitArgs.getSignature()"}) + public RArgsValuesAndNames prepare(RArgsValuesAndNames explicitArgs, S3DefaultArguments s3DefaultArguments, + @SuppressWarnings("unused") @Cached("explicitArgs.getSignature()") ArgumentsSignature cachedExplicitArgSignature, + @Cached("createArguments(cachedExplicitArgSignature)") MatchPermutation permutation) { + return ArgumentMatcher.matchArgumentsEvaluated(permutation, explicitArgs.getArguments(), s3DefaultArguments, formals); + } + + @Fallback + @TruffleBoundary + public RArgsValuesAndNames prepareGeneric(RArgsValuesAndNames evaluatedArgs, S3DefaultArguments s3DefaultArguments) { + return ArgumentMatcher.matchArgumentsEvaluated(formals, evaluatedArgs, s3DefaultArguments, callingNode); + } + + /** + * Returns the argument values and corresponding signature. The signature represents the + * original call signature reordered in the same way as the arguments. For s3DefaultArguments + * motivation see {@link RCallNode#callGroupGeneric}. + */ + public abstract RArgsValuesAndNames execute(RArgsValuesAndNames evaluatedArgs, S3DefaultArguments s3DefaultArguments); +} diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 889551762012e70bc1794788efb220562e771e06..4dc6655dedaa54fe57a143f4c1d1b901d0bce1d8 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -494,6 +494,11 @@ Note: method with signature ‘A2#A1’ chosen for function ‘foo’, #{ setClass('A1', representation(a='numeric')); setMethod('length', 'A1', function(x) x@a); obj <- new('A1'); obj@a <- 10; length(obj) } [1] 10 +##com.oracle.truffle.r.test.S4.TestS4.testMethods# +#{ setClass('A2', representation(a = 'numeric')); setMethod('rep', 'A2', function(x, a, b, c) { c(x@a, a, b, c) }); setMethod('ifelse', c(yes = 'A2'), function(test, yes, no) print(test)) } +Creating a generic function for ‘ifelse’ from package ‘base’ in the global environment +[1] "ifelse" + ##com.oracle.truffle.r.test.S4.TestS4.testMethods# #{ setGeneric("gen", function(o) standardGeneric("gen")); res<-print(setGeneric("gen", function(o) standardGeneric("gen"))); removeGeneric("gen"); res } [1] "gen" diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java index 7632dfa95cbbc52e1a7f0b446c107d0b255d8bd9..5fd998ead205c3529a1bfe442c67aa78d9cfeba4 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java @@ -126,6 +126,8 @@ public class TestS4 extends TestRBase { assertEval("setGeneric('do.call', signature = c('what', 'args'))"); assertEval("{ setClass('A1', representation(a='numeric')); setMethod('length', 'A1', function(x) x@a); obj <- new('A1'); obj@a <- 10; length(obj) }"); + + assertEval("{ setClass('A2', representation(a = 'numeric')); setMethod('rep', 'A2', function(x, a, b, c) { c(x@a, a, b, c) }); setMethod('ifelse', c(yes = 'A2'), function(test, yes, no) print(test)) }"); } @Test