diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/FileFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/FileFunctions.java index 63d8b68047883562b4511b147b70663fb83c40ca..e6194ce93d9c6568b63a2b4d8fd126805d4eb31f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/FileFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/FileFunctions.java @@ -67,9 +67,7 @@ import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.Utils; -import com.oracle.truffle.r.runtime.builtins.RBehavior; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; import com.oracle.truffle.r.runtime.context.ConsoleHandler; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -1018,7 +1016,7 @@ public class FileFunctions { } } - @RBuiltin(name = "file.show", kind = RBuiltinKind.INTERNAL, parameterNames = {"files", "header", "title", "delete.file", "pager"}, behavior = RBehavior.IO) + @RBuiltin(name = "file.show", kind = INTERNAL, parameterNames = {"files", "header", "title", "delete.file", "pager"}, visibility = OFF, behavior = IO) public abstract static class FileShow extends RBuiltinNode { @Override diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetTimeLimit.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetTimeLimit.java index 61e0a0edaf5a66fe8db29e29cfc1ad77115ca4b4..36c4516371412297a7aae02450ddebee22b158d0 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetTimeLimit.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetTimeLimit.java @@ -23,6 +23,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*; +import static com.oracle.truffle.r.runtime.RVisibility.OFF; import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; @@ -32,7 +33,7 @@ import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RNull; -@RBuiltin(name = "setTimeLimit", kind = INTERNAL, parameterNames = {"cpu", "elapsed", "transient"}, behavior = COMPLEX) +@RBuiltin(name = "setTimeLimit", kind = INTERNAL, parameterNames = {"cpu", "elapsed", "transient"}, visibility = OFF, behavior = COMPLEX) public abstract class SetTimeLimit extends RBuiltinNode { @Override diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SysFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SysFunctions.java index b020cfb2189514bbc70feaadb3186a4504db4033..3b5192cb48877a010dba5cd174f49da29a25f24c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SysFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SysFunctions.java @@ -382,7 +382,7 @@ public class SysFunctions { } } - @RBuiltin(name = "setFileTime", kind = INTERNAL, parameterNames = {"path", "time"}, behavior = IO) + @RBuiltin(name = "setFileTime", kind = INTERNAL, parameterNames = {"path", "time"}, visibility = OFF, behavior = IO) public abstract static class SysSetFileTime extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java index 9f848c51562ad640a6ac5c7331ebdd9be1c13c2b..9a435cdd16cbe7f67f6ab016b157a05ddf5f9ab0 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/AccessField.java @@ -39,7 +39,6 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ExtractListElement; import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode; -import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; @@ -49,6 +48,7 @@ import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RSpecialFactory; import com.oracle.truffle.r.runtime.data.RList; +import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RNode; @@ -59,19 +59,24 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { @Child private ExtractListElement extractListElement = ExtractListElement.create(); - @Specialization(guards = {"isSimpleList(list)", "isCached(list, field)", "list.getNames() != null"}) - public Object doList(RList list, String field, - @Cached("getIndex(list.getNames(), field)") int index) { + @Specialization(limit = "2", guards = {"isSimpleList(list)", "list.getNames() == cachedNames", "field == cachedField"}) + public Object doList(RList list, @SuppressWarnings("unused") String field, + @SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames, + @SuppressWarnings("unused") @Cached("field") String cachedField, + @Cached("getIndex(cachedNames, field)") int index) { if (index == -1) { throw RSpecialFactory.throwFullCallNeeded(); } - updateCache(list, field); return extractListElement.execute(list, index); } @Specialization(contains = "doList", guards = {"isSimpleList(list)", "list.getNames() != null"}) - public Object doListDynamic(RList list, String field, @Cached("create()") GetNamesAttributeNode getNamesNode) { - return doList(list, field, getIndex(getNamesNode.getNames(list), field)); + public Object doListDynamic(RList list, String field) { + int index = getIndex(getNamesNode.getNames(list), field); + if (index == -1) { + throw RSpecialFactory.throwFullCallNeeded(); + } + return extractListElement.execute(list, index); } @Fallback @@ -82,6 +87,7 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { } @RBuiltin(name = "$", kind = PRIMITIVE, parameterNames = {"", ""}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) public abstract class AccessField extends RBuiltinNode { @Child private ExtractVectorNode extract = ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, true); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java index 793d882e1a8121d8e29bcbdbcd993e9b78bbe6d3..cae45bca9afb8c891cd87dcfb198f9f132c5405f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/SpecialsUtils.java @@ -22,17 +22,28 @@ */ package com.oracle.truffle.r.nodes.builtin.base.infix; -import com.oracle.truffle.api.CompilerDirectives; -import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.NodeChild; +import com.oracle.truffle.api.dsl.NodeChildren; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.NodeCost; +import com.oracle.truffle.api.nodes.NodeInfo; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; +import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertIndexNodeGen; import com.oracle.truffle.r.nodes.function.ClassHierarchyNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.nodes.RNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; /** * Helper methods for implementing special calls. @@ -43,72 +54,86 @@ class SpecialsUtils { private static final String valueArgName = "value".intern(); public static boolean isCorrectUpdateSignature(ArgumentsSignature signature) { - return signature.getLength() == 3 && signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == valueArgName; + if (signature.getLength() == 3) { + return signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == valueArgName; + } else if (signature.getLength() == 4) { + return signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == null && signature.getName(3) == valueArgName; + } + return false; } /** * Common code shared between specials doing subset/subscript related operation. */ + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) abstract static class SubscriptSpecialCommon extends RNode { - protected final ValueProfile vectorClassProfile = ValueProfile.createClassProfile(); + protected final boolean inReplacement; - protected boolean isValidIndex(RAbstractVector vector, int index) { - vector = vectorClassProfile.profile(vector); - return index >= 1 && index <= vector.getLength(); + protected SubscriptSpecialCommon(boolean inReplacement) { + this.inReplacement = inReplacement; } - protected boolean isValidDoubleIndex(RAbstractVector vector, double index) { - return isValidIndex(vector, toIndex(index)); + /** + * Checks whether the given (1-based) index is valid for the given vector. + */ + protected static boolean isValidIndex(RAbstractVector vector, int index) { + return index >= 1 && index <= vector.getLength(); } /** - * Note: conversion from double to an index differs in subscript and subset. + * Checks if the value is single element that can be put into a list or vector as is, + * because in the case of vectors on the LSH of update we take each element and put it into + * the RHS of the update function. */ - protected int toIndex(double index) { - if (index == 0) { - return 0; - } - int i = (int) index; - return i == 0 ? 1 : i; + protected static boolean isSingleElement(Object value) { + return value instanceof Integer || value instanceof Double || value instanceof Byte || value instanceof String; + } + } + + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) + abstract static class SubscriptSpecial2Common extends SubscriptSpecialCommon { + + protected SubscriptSpecial2Common(boolean inReplacement) { + super(inReplacement); + } + + @Child private GetDimAttributeNode getDimensions = GetDimAttributeNode.create(); + + protected int matrixIndex(RAbstractVector vector, int index1, int index2) { + return index1 - 1 + ((index2 - 1) * getDimensions.getDimensions(vector)[0]); } - protected static int toIndexSubset(double index) { - return index == 0 ? 0 : (int) index; + /** + * Checks whether the given (1-based) indexes are valid for the given matrix. + */ + protected static boolean isValidIndex(RAbstractVector vector, int index1, int index2) { + int[] dimensions = vector.getDimensions(); + return dimensions != null && dimensions.length == 2 && index1 >= 1 && index1 <= dimensions[0] && index2 >= 1 && index2 <= dimensions[1]; } } /** * Common code shared between specials accessing/updating fields. */ + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) abstract static class ListFieldSpecialBase extends RNode { - @CompilationFinal private String cachedField; - @CompilationFinal private RStringVector cachedNames; + @Child private ClassHierarchyNode hierarchyNode = ClassHierarchyNode.create(); @Child protected GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); - protected final void updateCache(RList list, String field) { - if (cachedField == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - cachedField = field; - cachedNames = getNamesNode.getNames(list); - } - } - protected final boolean isSimpleList(RList list) { return hierarchyNode.execute(list) == null; } - protected final boolean isCached(RList list, String field) { - return cachedField == null || (cachedField == field && getNamesNode.getNames(list) == cachedNames); - } - protected static int getIndex(RStringVector names, String field) { - int fieldHash = field.hashCode(); - for (int i = 0; i < names.getLength(); i++) { - String current = names.getDataAt(i); - if (current == field || hashCodeEquals(current, fieldHash) && contentsEquals(current, field)) { - return i; + if (names != null) { + int fieldHash = field.hashCode(); + for (int i = 0; i < names.getLength(); i++) { + String current = names.getDataAt(i); + if (current == field || hashCodeEquals(current, fieldHash) && contentsEquals(current, field)) { + return i; + } } } return -1; @@ -124,4 +149,79 @@ class SpecialsUtils { return current.hashCode() == fieldHash; } } + + @NodeInfo(cost = NodeCost.NONE) + public static final class ProfiledValue extends RBaseNode { + + private final ValueProfile profile = ValueProfile.createClassProfile(); + + @Child private RNode delegate; + + protected ProfiledValue(RNode delegate) { + this.delegate = delegate; + } + + public Object execute(VirtualFrame frame) { + return profile.profile(delegate.execute(frame)); + } + + @Override + protected RSyntaxNode getRSyntaxNode() { + return delegate.asRSyntaxNode(); + } + } + + @NodeInfo(cost = NodeCost.NONE) + @NodeChildren({@NodeChild(value = "delegate", type = RNode.class)}) + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) + public abstract static class ConvertIndex extends RNode { + + private final boolean isSubset; + private final ConditionProfile zeroProfile; + + ConvertIndex(boolean isSubset) { + this.isSubset = isSubset; + this.zeroProfile = isSubset ? null : ConditionProfile.createBinaryProfile(); + } + + protected abstract RNode getDelegate(); + + @Specialization + protected static int convertInteger(int value) { + return value; + } + + @Specialization + protected int convertDouble(double value) { + // Conversion from double to an index differs in subscript and subset. + int intValue = (int) value; + if (isSubset) { + return intValue; + } else { + return zeroProfile.profile(intValue == 0) ? (value == 0 ? 0 : 1) : intValue; + } + } + + @Specialization(contains = {"convertInteger", "convertDouble"}) + protected Object convert(Object value) { + return value; + } + + @Override + protected RSyntaxNode getRSyntaxNode() { + return getDelegate().asRSyntaxNode(); + } + } + + public static ProfiledValue profile(RNode value) { + return new ProfiledValue(value); + } + + public static ConvertIndex convertSubscript(RNode value) { + return ConvertIndexNodeGen.create(false, value); + } + + public static ConvertIndex convertSubset(RNode value) { + return ConvertIndexNodeGen.create(true, value); + } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java index ed74d6b51b07cdd7a9495ded353dfff1c8343d90..4c8fe6acd36096bbcb2bcf80383deb43a80b5156 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subscript.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin.base.infix; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -29,14 +31,17 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.NodeChild; +import com.oracle.truffle.api.dsl.NodeChildren; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ExtractListElement; import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common; import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon; import com.oracle.truffle.r.nodes.function.ClassHierarchyNode; import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen; @@ -50,6 +55,7 @@ import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RLogical; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.RTypesFlatLayout; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; @@ -60,10 +66,13 @@ import com.oracle.truffle.r.runtime.nodes.RNode; /** * Subscript code for vectors minus list is the same as subset code, this class allows sharing it. */ -@NodeChild(value = "arguments", type = RNode[].class) -@TypeSystemReference(EmptyTypeSystemFlatLayout.class) +@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index", type = ConvertIndex.class)}) abstract class SubscriptSpecialBase extends SubscriptSpecialCommon { + protected SubscriptSpecialBase(boolean inReplacement) { + super(inReplacement); + } + @Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false); protected boolean simpleVector(RAbstractVector vector) { @@ -72,58 +81,108 @@ abstract class SubscriptSpecialBase extends SubscriptSpecialCommon { @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) protected int access(RAbstractIntVector vector, int index) { - return vectorClassProfile.profile(vector).getDataAt(index - 1); + return vector.getDataAt(index - 1); } @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) protected double access(RAbstractDoubleVector vector, int index) { - return vectorClassProfile.profile(vector).getDataAt(index - 1); + return vector.getDataAt(index - 1); } @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) protected String access(RAbstractStringVector vector, int index) { - return vectorClassProfile.profile(vector).getDataAt(index - 1); + return vector.getDataAt(index - 1); } - @Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"}) - protected int access(RAbstractIntVector vector, double index) { - return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1); + @SuppressWarnings("unused") + @Fallback + protected static Object access(Object vector, Object index) { + throw RSpecialFactory.throwFullCallNeeded(); } +} + +/** + * Subscript code for matrices minus list is the same as subset code, this class allows sharing it. + */ +@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index1", type = ConvertIndex.class), @NodeChild(value = "index2", type = ConvertIndex.class)}) +abstract class SubscriptSpecial2Base extends SubscriptSpecial2Common { - @Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"}) - protected double access(RAbstractDoubleVector vector, double index) { - return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1); + protected SubscriptSpecial2Base(boolean inReplacement) { + super(inReplacement); } - @Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"}) - protected String access(RAbstractStringVector vector, double index) { - return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1); + @Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false); + + protected abstract ProfiledValue getVector(); + + protected abstract ConvertIndex getIndex1(); + + protected abstract ConvertIndex getIndex2(); + + protected boolean simpleVector(RAbstractVector vector) { + return classHierarchy.execute(vector) == null; + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"}) + protected int access(RAbstractIntVector vector, int index1, int index2) { + return vector.getDataAt(matrixIndex(vector, index1, index2)); + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"}) + protected double access(RAbstractDoubleVector vector, int index1, int index2) { + return vector.getDataAt(matrixIndex(vector, index1, index2)); + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"}) + protected String access(RAbstractStringVector vector, int index1, int index2) { + return vector.getDataAt(matrixIndex(vector, index1, index2)); } @SuppressWarnings("unused") @Fallback - protected static Object access(Object vector, Object index) { + protected static Object access(Object vector, Object index1, Object index2) { throw RSpecialFactory.throwFullCallNeeded(); } } -@TypeSystemReference(EmptyTypeSystemFlatLayout.class) abstract class SubscriptSpecial extends SubscriptSpecialBase { - @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) + protected SubscriptSpecial(boolean inReplacement) { + super(inReplacement); + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)", "!inReplacement"}) protected static Object access(RList vector, int index, @Cached("create()") ExtractListElement extract) { return extract.execute(vector, index - 1); } - @Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"}) - protected Object access(RList vector, double index, + protected static ExtractVectorNode createAccess() { + return ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false); + } + + @Specialization(guards = {"simpleVector(vector)", "!inReplacement"}) + protected static Object access(VirtualFrame frame, RAbstractVector vector, Object index, + @Cached("createAccess()") ExtractVectorNode extract) { + return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE); + } +} + +abstract class SubscriptSpecial2 extends SubscriptSpecial2Base { + + protected SubscriptSpecial2(boolean inReplacement) { + super(inReplacement); + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)", "!inReplacement"}) + protected Object access(RList vector, int index1, int index2, @Cached("create()") ExtractListElement extract) { - return extract.execute(vector, toIndex(index) - 1); + return extract.execute(vector, matrixIndex(vector, index1, index2)); } } @RBuiltin(name = "[[", kind = PRIMITIVE, parameterNames = {"x", "...", "exact", "drop"}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(RTypesFlatLayout.class) public abstract class Subscript extends RBuiltinNode { @RBuiltin(name = ".subset2", kind = PRIMITIVE, parameterNames = {"x", "...", "exact", "drop"}, behavior = PURE) @@ -131,8 +190,15 @@ public abstract class Subscript extends RBuiltinNode { // same implementation as "[[", with different dispatch } - public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) { - return signature.getNonNullCount() == 0 && arguments.length == 2 ? SubscriptSpecialNodeGen.create(arguments) : null; + public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) { + if (signature.getNonNullCount() == 0) { + if (arguments.length == 2) { + return SubscriptSpecialNodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1])); + } else if (arguments.length == 3) { + return SubscriptSpecial2NodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1]), convertSubscript(arguments[2])); + } + } + return null; } @Child private ExtractVectorNode extractNode = ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subset.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subset.java index 9d693de22845ef0dba043d41824103eeda08015e..05954ea72e696cff2514a33da821e631656ee65b 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subset.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/Subset.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin.base.infix; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -37,7 +39,10 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue; import com.oracle.truffle.r.runtime.ArgumentsSignature; +import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -52,36 +57,63 @@ import com.oracle.truffle.r.runtime.nodes.RNode; * Subset special only handles single element integer/double index. In the case of list, we need to * create the actual list otherwise we just return the primitive type. */ -@TypeSystemReference(EmptyTypeSystemFlatLayout.class) abstract class SubsetSpecial extends SubscriptSpecialBase { @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); - @Override - protected boolean simpleVector(RAbstractVector vector) { - vector = vectorClassProfile.profile(vector); - return super.simpleVector(vector) && getNamesNode.getNames(vector) == null; + protected SubsetSpecial(boolean inReplacement) { + super(inReplacement); } @Override - protected int toIndex(double index) { - return toIndexSubset(index); + protected boolean simpleVector(RAbstractVector vector) { + return super.simpleVector(vector) && getNamesNode.getNames(vector) == null; } - @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"}) + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)", "!inReplacement"}) protected static RList access(RList vector, int index, @Cached("create()") ExtractListElement extract) { return RDataFactory.createList(new Object[]{extract.execute(vector, index - 1)}); } - @Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"}) - protected static RList access(RList vector, double index, + protected static ExtractVectorNode createAccess() { + return ExtractVectorNode.create(ElementAccessMode.SUBSET, false); + } + + @Specialization(guards = {"simpleVector(vector)", "!inReplacement"}) + protected static Object access(VirtualFrame frame, RAbstractVector vector, Object index, + @Cached("createAccess()") ExtractVectorNode extract) { + return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE); + } +} + +/** + * Subset special only handles single element integer/double index. In the case of list, we need to + * create the actual list otherwise we just return the primitive type. + */ +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) +abstract class SubsetSpecial2 extends SubscriptSpecial2Base { + + @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); + + protected SubsetSpecial2(boolean inReplacement) { + super(inReplacement); + } + + @Override + protected boolean simpleVector(RAbstractVector vector) { + return super.simpleVector(vector) && getNamesNode.getNames(vector) == null; + } + + @Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)", "!inReplacement"}) + protected RList access(RList vector, int index1, int index2, @Cached("create()") ExtractListElement extract) { - return RDataFactory.createList(new Object[]{extract.execute(vector, toIndexSubset(index) - 1)}); + return RDataFactory.createList(new Object[]{extract.execute(vector, matrixIndex(vector, index1, index2))}); } } @RBuiltin(name = "[", kind = PRIMITIVE, parameterNames = {"x", "...", "drop"}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) public abstract class Subset extends RBuiltinNode { @RBuiltin(name = ".subset", kind = PRIMITIVE, parameterNames = {"", "...", "drop"}, behavior = PURE) @@ -89,14 +121,17 @@ public abstract class Subset extends RBuiltinNode { // same implementation as "[", with different dispatch } - public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) { - boolean correctSignature = signature.getNonNullCount() == 0 && arguments.length == 2; - if (!correctSignature) { - return null; + public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { + if (signature.getNonNullCount() == 0 && (args.length == 2 || args.length == 3)) { + ProfiledValue profiledVector = profile(args[0]); + ConvertIndex index = convertSubset(args[1]); + if (args.length == 2) { + return SubsetSpecialNodeGen.create(inReplacement, profiledVector, index); + } else { + return SubsetSpecial2NodeGen.create(inReplacement, profiledVector, index, convertSubset(args[2])); + } } - // Subset adds support for lists returning newly created list, which cannot work when used - // in replacement, because we need the reference to the existing (materialized) list element - return inReplacement ? SubscriptSpecialBaseNodeGen.create(arguments) : SubsetSpecialNodeGen.create(arguments); + return null; } @Child private ExtractVectorNode extractNode = ExtractVectorNode.create(ElementAccessMode.SUBSET, false); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java index ebbd01a056c1e63bb5e0436186690bb062b07a2c..edb89de721d58a14448121bc9a634559c5c12258 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java @@ -32,8 +32,10 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; @@ -48,6 +50,7 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RSpecialFactory; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RNode; @@ -65,13 +68,14 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { return value != RNull.instance && !(value instanceof RList); } - @Specialization(guards = {"isSimpleList(list)", "!list.isShared()", "isCached(list, field)", "list.getNames() != null", "isNotRNullRList(value)"}) - public RList doList(RList list, String field, Object value, - @Cached("getIndex(list.getNames(), field)") int index) { + @Specialization(limit = "2", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() == cachedNames", "field == cachedField", "isNotRNullRList(value)"}) + public Object doList(RList list, @SuppressWarnings("unused") String field, Object value, + @SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames, + @SuppressWarnings("unused") @Cached("field") String cachedField, + @Cached("getIndex(cachedNames, field)") int index) { if (index == -1) { throw RSpecialFactory.throwFullCallNeeded(value); } - updateCache(list, field); Object sharedValue = value; // share only when necessary: if (list.getDataAt(index) != value) { @@ -83,7 +87,17 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { @Specialization(contains = "doList", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() != null", "isNotRNullRList(value)"}) public RList doListDynamic(RList list, String field, Object value) { - return doList(list, field, value, getIndex(getNamesNode.getNames(list), field)); + int index = getIndex(getNamesNode.getNames(list), field); + if (index == -1) { + throw RSpecialFactory.throwFullCallNeeded(value); + } + Object sharedValue = value; + // share only when necessary: + if (list.getDataAt(index) != value) { + sharedValue = getShareObjectNode().execute(value); + } + list.setElement(index, sharedValue); + return list; } @SuppressWarnings("unused") @@ -102,9 +116,10 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase { } @RBuiltin(name = "$<-", kind = PRIMITIVE, parameterNames = {"", "", "value"}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) public abstract class UpdateField extends RBuiltinNode { - @Child private ReplaceVectorNode extract = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, true); + @Child private ReplaceVectorNode update = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, true); @Child private CastListNode castList; private final ConditionProfile coerceList = ConditionProfile.createBinaryProfile(); @@ -121,7 +136,7 @@ public abstract class UpdateField extends RBuiltinNode { @Specialization protected Object update(VirtualFrame frame, Object container, String field, Object value) { Object list = coerceList.profile(container instanceof RAbstractListVector) ? container : coerceList(container); - return extract.apply(frame, list, new Object[]{field}, value); + return update.apply(frame, list, new Object[]{field}, value); } private Object coerceList(Object vector) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java index 5a68b3998b0e09edf9ddc0c2374f005f7b0675f9..4825f4efd34fd8b96eec20750d6ae33623f69e9d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin.base.infix; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -30,6 +32,7 @@ import java.util.Arrays; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.NodeChild; +import com.oracle.truffle.api.dsl.NodeChildren; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.frame.VirtualFrame; @@ -38,6 +41,9 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common; import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon; import com.oracle.truffle.r.nodes.function.ClassHierarchyNode; import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen; @@ -54,25 +60,21 @@ import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.nodes.RNode; import com.oracle.truffle.r.runtime.ops.na.NACheck; -@NodeChild(value = "arguments", type = RNode[].class) -@TypeSystemReference(EmptyTypeSystemFlatLayout.class) +@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index", type = ConvertIndex.class), @NodeChild(value = "value", type = RNode.class)}) abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon { + + protected UpdateSubscriptSpecial(boolean inReplacement) { + super(inReplacement); + } + @Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false); + private final NACheck naCheck = NACheck.create(); protected boolean simple(Object vector) { return classHierarchy.execute(vector) == null; } - /** - * Checks if the value is single element that can be put into a list or vector as is, because in - * the case of vectors on the LSH of update we take each element and put it into the RHS of the - * update function. - */ - protected static boolean isSingleElement(Object value) { - return value instanceof Integer || value instanceof Double || value instanceof Byte || value instanceof String; - } - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"}) protected RIntVector set(RIntVector vector, int index, int value) { return vector.updateDataAt(index - 1, value, naCheck); @@ -94,53 +96,86 @@ abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon { return list; } - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"}) - protected RIntVector setDoubleIndex(RIntVector vector, double index, int value) { - return vector.updateDataAt(toIndex(index) - 1, value, naCheck); - } - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"}) protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index, int value) { - return vector.updateDataAt(toIndex(index) - 1, value, naCheck); + return vector.updateDataAt(index - 1, value, naCheck); } - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"}) - protected RDoubleVector setDoubleIndexIntValue(RDoubleVector vector, double index, int value) { - return vector.updateDataAt(toIndex(index) - 1, value, naCheck); + @SuppressWarnings("unused") + @Fallback + protected static Object setFallback(Object vector, Object index, Object value) { + throw RSpecialFactory.throwFullCallNeeded(value); } +} - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"}) - protected RDoubleVector setDoubleIndex(RDoubleVector vector, double index, double value) { - return vector.updateDataAt(toIndex(index) - 1, value, naCheck); +@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index1", type = ConvertIndex.class), @NodeChild(value = "index2", type = ConvertIndex.class), + @NodeChild(value = "value", type = RNode.class)}) +abstract class UpdateSubscriptSpecial2 extends SubscriptSpecial2Common { + + protected UpdateSubscriptSpecial2(boolean inReplacement) { + super(inReplacement); } - @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"}) - protected RStringVector setDoubleIndex(RStringVector vector, double index, String value) { - return vector.updateDataAt(toIndex(index) - 1, value, naCheck); + @Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false); + + private final NACheck naCheck = NACheck.create(); + + protected boolean simple(Object vector) { + return classHierarchy.execute(vector) == null; + } + + @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"}) + protected RIntVector set(RIntVector vector, int index1, int index2, int value) { + return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck); + } + + @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"}) + protected RDoubleVector set(RDoubleVector vector, int index1, int index2, double value) { + return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck); + } + + @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"}) + protected RStringVector set(RStringVector vector, int index1, int index2, String value) { + return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck); } - @Specialization(guards = {"simple(list)", "!list.isShared()", "isValidDoubleIndex(list, index)", "isSingleElement(value)"}) - protected Object setDoubleIndex(RList list, double index, Object value) { - list.setDataAt(list.getInternalStore(), toIndex(index) - 1, value); + @Specialization(guards = {"simple(list)", "!list.isShared()", "isValidIndex(list, index1, index2)", "isSingleElement(value)"}) + protected Object set(RList list, int index1, int index2, Object value) { + list.setDataAt(list.getInternalStore(), matrixIndex(list, index1, index2), value); return list; } + @Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"}) + protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index1, int index2, int value) { + return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck); + } + @SuppressWarnings("unused") @Fallback - protected static Object setFallback(Object vector, Object index, Object value) { + protected static Object setFallback(Object vector, Object index1, Object index2, Object value) { throw RSpecialFactory.throwFullCallNeeded(value); } } @RBuiltin(name = "[[<-", kind = PRIMITIVE, parameterNames = {"", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) public abstract class UpdateSubscript extends RBuiltinNode { @Child private ReplaceVectorNode replaceNode = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, false); private final ConditionProfile argsLengthLargerThanOneProfile = ConditionProfile.createBinaryProfile(); - public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) { - return SpecialsUtils.isCorrectUpdateSignature(signature) && arguments.length == 3 ? UpdateSubscriptSpecialNodeGen.create(arguments) : null; + public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { + if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) { + ProfiledValue vector = profile(args[0]); + ConvertIndex index = convertSubscript(args[1]); + if (args.length == 3) { + return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]); + } else { + return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubscript(args[2]), args[3]); + } + } + return null; } @Specialization(guards = "!args.isEmpty()") diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubset.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubset.java index 5c92d85bf3ca0df6b123e20f9cec2b15573b631a..820e2ebcf5835bbc9f616f6c5c1611fde51ffe86 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubset.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubset.java @@ -22,13 +22,14 @@ */ package com.oracle.truffle.r.nodes.builtin.base.infix; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset; +import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import java.util.Arrays; -import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.frame.VirtualFrame; @@ -37,6 +38,8 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode; import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex; +import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.builtins.RBuiltin; @@ -44,24 +47,24 @@ import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.nodes.RNode; -@NodeChild(value = "arguments", type = RNode[].class) -@TypeSystemReference(EmptyTypeSystemFlatLayout.class) -abstract class UpdateSubsetSpecial extends UpdateSubscriptSpecial { - - @Override - protected int toIndex(double index) { - return toIndexSubset(index); - } -} - @RBuiltin(name = "[<-", kind = PRIMITIVE, parameterNames = {"", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) public abstract class UpdateSubset extends RBuiltinNode { @Child private ReplaceVectorNode replaceNode = ReplaceVectorNode.create(ElementAccessMode.SUBSET, false); private final ConditionProfile argsLengthLargerThanOneProfile = ConditionProfile.createBinaryProfile(); - public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) { - return SpecialsUtils.isCorrectUpdateSignature(signature) && arguments.length == 3 ? UpdateSubsetSpecialNodeGen.create(arguments) : null; + public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { + if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) { + ProfiledValue vector = profile(args[0]); + ConvertIndex index = convertSubset(args[1]); + if (args.length == 3) { + return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]); + } else { + return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubset(args[2]), args[3]); + } + } + return null; } @Specialization(guards = "!args.isEmpty()") diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java index b72e022b321a7d8fa76f41a506f474f6bd84ab90..b92c005198d339616de71ffa80264e1ff7848959 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/SpecialCallTest.java @@ -28,11 +28,6 @@ import org.junit.Test; import com.oracle.truffle.api.RootCallTarget; import com.oracle.truffle.api.source.Source; import com.oracle.truffle.r.engine.TruffleRLanguage; -import com.oracle.truffle.r.nodes.access.WriteVariableSyntaxNode; -import com.oracle.truffle.r.nodes.control.BlockNode; -import com.oracle.truffle.r.nodes.control.ReplacementDispatchNode; -import com.oracle.truffle.r.nodes.function.RCallNode; -import com.oracle.truffle.r.nodes.function.RCallSpecialNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.FastROptions; import com.oracle.truffle.r.runtime.RError; @@ -59,12 +54,23 @@ public class SpecialCallTest extends TestBase { @Override protected Void visit(RSyntaxCall element) { - if (element instanceof RCallSpecialNode) { - special++; - } else if (element instanceof RCallNode) { - normal++; - } else { - assert element instanceof ReplacementDispatchNode || element instanceof WriteVariableSyntaxNode || element instanceof BlockNode : "unexpected node while testing"; + switch (element.getClass().getSimpleName()) { + case "SpecialReplacementNode": + case "SpecialVoidReplacementNode": + case "RCallSpecialNode": + special++; + break; + case "RCallNodeGen": + case "GenericReplacementNode": + normal++; + break; + case "ReplacementDispatchNode": + case "WriteVariableSyntaxNode": + case "BlockNode": + // ignored + break; + default: + throw new AssertionError("unexpected class: " + element.getClass().getSimpleName()); } accept(element.getSyntaxLHS()); for (RSyntaxElement arg : element.getSyntaxArguments()) { @@ -158,81 +164,133 @@ public class SpecialCallTest extends TestBase { assertCallCounts("1 + 1", 1, 0, 1, 0); assertCallCounts("1 + 1 * 2 + 4", 3, 0, 3, 0); - assertCallCounts("{ a <- 1; b <- 2; a + b }", 1, 0, 1, 0); - assertCallCounts("{ a <- 1; b <- 2; c <- 3; a + b * 2 * c}", 3, 0, 3, 0); + assertCallCounts("{ a <- 1; b <- 2 }", "a + b", 1, 0, 1, 0); + assertCallCounts("{ a <- 1; b <- 2; c <- 3 }", "a + b * 2 * c", 3, 0, 3, 0); - assertCallCounts("{ a <- data.frame(a=1); b <- 2; c <- 3; a + b * 2 * c}", 3, 1, 2, 2); - assertCallCounts("{ a <- 1; b <- data.frame(a=1); c <- 3; a + b * 2 * c}", 3, 1, 0, 4); + assertCallCounts("{ a <- data.frame(a=1); b <- 2; c <- 3 }", "a + b * 2 * c", 3, 0, 2, 1); + assertCallCounts("{ a <- 1; b <- data.frame(a=1); c <- 3 }", "a + b * 2 * c", 3, 0, 0, 3); assertCallCounts("1 %*% 1", 0, 1, 0, 1); } @Test public void testSubset() { - assertCallCounts("{ a <- 1:10; a[1] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[2] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[4] }", 1, 1, 1, 1); - assertCallCounts("{ a <- list(c(1,2,3,4),2,3); a[1] }", 1, 2, 1, 2); - - assertCallCounts("{ a <- c(1,2,3,4); a[0.1] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[5] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[0] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[-1] }", 0, 3, 0, 3); // "-1" is a unary expression - assertCallCounts("{ a <- c(1,2,3,4); b <- -1; a[b] }", 1, 2, 0, 3); - assertCallCounts("{ a <- c(1,2,3,4); a[NA_integer_] }", 1, 1, 0, 2); + assertCallCounts("a <- 1:10", "a[1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[2]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[4]", 1, 0, 1, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0.1]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[5]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0]", 1, 0, 1, 0); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[b]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_]", 1, 0, 1, 0); + + assertCallCounts("a <- c(1,2,3,4)", "a[-1]", 0, 2, 0, 2); // "-1" is a unary expression + assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F]", 0, 1, 0, 1); } @Test public void testSubscript() { - assertCallCounts("{ a <- 1:10; a[[1]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[[2]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- c(1,2,3,4); a[[4]] }", 1, 1, 1, 1); - assertCallCounts("{ a <- list(c(1,2,3,4),2,3); a[[1]] }", 1, 2, 1, 2); - assertCallCounts("{ a <- list(a=c(1,2,3,4),2,3); a[[1]] }", 1, 2, 1, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[[0.1]] }", 1, 1, 1, 1); - - assertCallCounts("{ a <- c(1,2,3,4); a[[5]] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); a[[0]] }", 1, 1, 0, 2); - assertCallCounts("{ a <- c(1,2,3,4); b <- -1; a[[b]] }", 1, 2, 0, 3); - assertCallCounts("{ a <- c(1,2,3,4); a[[NA_integer_]] }", 1, 1, 0, 2); + assertCallCounts("a <- 1:10", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[2]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[4]]", 1, 0, 1, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[5]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0]]", 1, 0, 1, 0); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[[b]]", 1, 0, 1, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[NA_integer_]]", 1, 0, 1, 0); + + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=T, 1]]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=F, 1]]", 0, 1, 0, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[1, drop=F]]", 0, 1, 0, 1); } - private static void assertCallCounts(String str, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { + @Test + public void testUpdateSubset() { + assertCallCounts("a <- 1:10", "a[1] <- 1", 1, 0, 1, 1); // sequence + assertCallCounts("a <- c(1,2,3,4)", "a[2] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[4] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[1] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[0.1] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[5] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[0] <- 1", 1, 0, 1, 1); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[b] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[NA_integer_] <- 1", 1, 0, 1, 1); + + assertCallCounts("a <- c(1,2,3,4)", "a[-1] <- 1", 0, 2, 0, 3); // "-1" is a unary expression + assertCallCounts("a <- c(1,2,3,4)", "a[drop=T, 1] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[drop=F, 1] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[1, drop=F] <- 1", 0, 1, 0, 2); + } + + @Test + public void testUpdateSubscript() { + assertCallCounts("a <- 1:10", "a[[1]] <- 1", 1, 0, 1, 1); // sequence + assertCallCounts("a <- c(1,2,3,4)", "a[[2]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[4]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]] <- 1", 1, 0, 2, 0); + assertCallCounts("a <- c(1,2,3,4)", "a[[5]] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[0]] <- 1", 1, 0, 1, 1); + assertCallCounts("{ a <- c(1,2,3,4); b <- -1 }", "a[[b]] <- 1", 1, 0, 1, 1); + assertCallCounts("a <- c(1,2,3,4)", "a[[NA_integer_]] <- 1", 1, 0, 1, 1); + + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=T, 1]] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[[drop=F, 1]] <- 1", 0, 1, 0, 2); + assertCallCounts("a <- c(1,2,3,4)", "a[[1, drop=F]] <- 1", 0, 1, 0, 2); + } + + private static void assertCallCounts(String test, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { + assertCallCounts("{}", test, initialSpecialCount, initialNormalCount, finalSpecialCount, finalNormalCount); + } + + private static void assertCallCounts(String setup, String test, int initialSpecialCount, int initialNormalCount, int finalSpecialCount, int finalNormalCount) { if (!FastROptions.UseSpecials.getBooleanValue()) { return; } - Source source = Source.newBuilder(str).mimeType(TruffleRLanguage.MIME).name("test").build(); + Source setupSource = Source.newBuilder(setup).mimeType(TruffleRLanguage.MIME).name("test").build(); + Source testSource = Source.newBuilder(test).mimeType(TruffleRLanguage.MIME).name("test").build(); - RExpression expression = testVMContext.getThisEngine().parse(source); - assert expression.getLength() == 1; - RootCallTarget callTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) expression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); + RExpression setupExpression = testVMContext.getThisEngine().parse(setupSource); + RExpression testExpression = testVMContext.getThisEngine().parse(testSource); + assert setupExpression.getLength() == 1; + assert testExpression.getLength() == 1; + RootCallTarget setupCallTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) setupExpression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); + RootCallTarget testCallTarget = testVMContext.getThisEngine().makePromiseCallTarget(((RLanguage) testExpression.getDataAt(0)).getRep().asRSyntaxNode().asRNode(), "test"); try { - CountCallsVisitor count1 = new CountCallsVisitor(callTarget); - Assert.assertEquals("initial special call count '" + str + "': ", initialSpecialCount, count1.special); - Assert.assertEquals("initial normal call count '" + str + "': ", initialNormalCount, count1.normal); + CountCallsVisitor count1 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("initial special call count '" + setup + "; " + test + "': ", initialSpecialCount, count1.special); + Assert.assertEquals("initial normal call count '" + setup + "; " + test + "': ", initialNormalCount, count1.normal); try { - callTarget.call(REnvironment.globalEnv().getFrame()); + setupCallTarget.call(REnvironment.globalEnv().getFrame()); + testCallTarget.call(REnvironment.globalEnv().getFrame()); } catch (RError e) { // ignore } - CountCallsVisitor count2 = new CountCallsVisitor(callTarget); - Assert.assertEquals("special call count after first call '" + str + "': ", finalSpecialCount, count2.special); - Assert.assertEquals("normal call count after first call '" + str + "': ", finalNormalCount, count2.normal); + CountCallsVisitor count2 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("special call count after first call '" + setup + "; " + test + "': ", finalSpecialCount, count2.special); + Assert.assertEquals("normal call count after first call '" + setup + "; " + test + "': ", finalNormalCount, count2.normal); try { - callTarget.call(REnvironment.globalEnv().getFrame()); + setupCallTarget.call(REnvironment.globalEnv().getFrame()); + testCallTarget.call(REnvironment.globalEnv().getFrame()); } catch (RError e) { // ignore } - CountCallsVisitor count3 = new CountCallsVisitor(callTarget); - Assert.assertEquals("special call count after second call '" + str + "': ", finalSpecialCount, count3.special); - Assert.assertEquals("normal call count after second call '" + str + "': ", finalNormalCount, count3.normal); + CountCallsVisitor count3 = new CountCallsVisitor(testCallTarget); + Assert.assertEquals("special call count after second call '" + setup + "; " + test + "': ", finalSpecialCount, count3.special); + Assert.assertEquals("normal call count after second call '" + setup + "; " + test + "': ", finalNormalCount, count3.normal); } catch (AssertionError e) { - new PrintCallsVisitor().print(callTarget); + new PrintCallsVisitor().print(testCallTarget); throw e; } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/ExtractListElement.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/ExtractListElement.java index 5a6b09782203e748ec27868e9d68f7dd2a4c14d6..485ca001bcc8c4dc4266b01c4640968f12a74d5e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/ExtractListElement.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/ExtractListElement.java @@ -57,7 +57,8 @@ public abstract class ExtractListElement extends Node { } @Specialization - protected Object doList(RListBase list, int index, @Cached("create()") UpdateShareableChildValueNode updateStateNode) { + protected Object doList(RListBase list, int index, + @Cached("create()") UpdateShareableChildValueNode updateStateNode) { Object element = list.getDataAt(index); return updateStateNode.updateState(list, element); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java index 936e2babd89f1a40e0485779285835a07c622653..8fea3922499d8129309a55a96d5a685689993d7a 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java @@ -69,6 +69,10 @@ abstract class ReplacementNode extends OperatorNode { // Note: if specials are turned off in FastR, onlySpecials will never be true boolean createSpecial = hasOnlySpecialCalls(calls); if (createSpecial) { + /* + * This assumes that whenever there's a special call for the "extract", there's also a + * special call for "replace". + */ if (isVoid) { return new SpecialVoidReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex); } else { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java index 830bf51331a185cf2cc207e1bc7c22577f1aace2..9834af4dee1da8e4bd95dc84352c836dfcb0d40e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java @@ -197,16 +197,17 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode } RNode[] localArguments = new RNode[arguments.length]; for (int i = 0; i < arguments.length; i++) { + RSyntaxNode arg = arguments[i]; if (inReplace && contains(ignoredArguments, i)) { - localArguments[i] = arguments[i].asRNode(); + localArguments[i] = arg.asRNode(); } else { - if (arguments[i] instanceof RSyntaxLookup) { - localArguments[i] = new PeekLocalVariableNode(((RSyntaxLookup) arguments[i]).getIdentifier()); - } else if (arguments[i] instanceof RSyntaxConstant) { - localArguments[i] = RContext.getASTBuilder().process(arguments[i]).asRNode(); + if (arg instanceof RSyntaxLookup) { + localArguments[i] = new PeekLocalVariableNode(((RSyntaxLookup) arg).getIdentifier()); + } else if (arg instanceof RSyntaxConstant) { + localArguments[i] = RContext.getASTBuilder().process(arg).asRNode(); } else { - assert arguments[i] instanceof RCallSpecialNode; - localArguments[i] = arguments[i].asRNode(); + assert arg instanceof RCallSpecialNode; + localArguments[i] = arg.asRNode(); } } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RTypesFlatLayout.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RTypesFlatLayout.java new file mode 100644 index 0000000000000000000000000000000000000000..ed769f7b79bb5da49e260d32d43b057406d81318 --- /dev/null +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RTypesFlatLayout.java @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.runtime.data; + +import com.oracle.truffle.api.dsl.ImplicitCast; +import com.oracle.truffle.api.dsl.TypeCast; +import com.oracle.truffle.api.dsl.TypeCheck; +import com.oracle.truffle.api.dsl.TypeSystem; +import com.oracle.truffle.api.dsl.internal.DSLOptions; +import com.oracle.truffle.api.dsl.internal.DSLOptions.DSLGenerator; +import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RNode; + +/** + * Whenever you add a type {@code T} to the list below, make sure a corresponding {@code executeT()} + * method is added to {@link RNode}, a {@code typeof} method is added to {@code TypeoNode} and a + * {@code print} method added to {code PrettyPrinterNode}. + * + * @see RNode + */ +@TypeSystem({byte.class, int.class, double.class}) +@DSLOptions(defaultGenerator = DSLGenerator.FLAT) +public class RTypesFlatLayout { + + @TypeCheck(RNull.class) + public static boolean isRNull(Object value) { + return value == RNull.instance; + } + + @TypeCast(RNull.class) + @SuppressWarnings("unused") + public static RNull asRNull(Object value) { + return RNull.instance; + } + + @TypeCheck(RMissing.class) + public static boolean isRMissing(Object value) { + return value == RMissing.instance; + } + + @TypeCast(RMissing.class) + @SuppressWarnings("unused") + public static RMissing asRMissing(Object value) { + return RMissing.instance; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(int value) { + return RDataFactory.createIntVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(double value) { + return RDataFactory.createDoubleVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RRaw value) { + return RDataFactory.createRawVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(byte value) { + return RDataFactory.createLogicalVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RComplex value) { + return RDataFactory.createComplexVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(String value) { + return RDataFactory.createStringVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RIntVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RDoubleVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RLogicalVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RComplexVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RRawVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RStringVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RIntSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RDoubleSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RList vector) { + return vector; + } + + @ImplicitCast + public static RAbstractContainer toAbstractContainer(RAbstractVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(int value) { + return RDataFactory.createIntVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(double value) { + return RDataFactory.createDoubleVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RRaw value) { + return RDataFactory.createRawVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(byte value) { + return RDataFactory.createLogicalVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RComplex value) { + return RDataFactory.createComplexVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(String value) { + return RDataFactory.createStringVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RIntVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RDoubleVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RLogicalVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RComplexVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RRawVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RStringVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RIntSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RDoubleSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractVector toAbstractVector(RList vector) { + return vector; + } + + @ImplicitCast + public static RAbstractIntVector toAbstractIntVector(int value) { + return RDataFactory.createIntVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractIntVector toAbstractIntVector(RIntSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractIntVector toAbstractIntVector(RIntVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractDoubleVector toAbstractDoubleVector(double value) { + return RDataFactory.createDoubleVectorFromScalar(value); + } + + @ImplicitCast + public static RAbstractDoubleVector toAbstractDoubleVector(RDoubleSequence vector) { + return vector; + } + + @ImplicitCast + public static RAbstractDoubleVector toAbstractDoubleVector(RDoubleVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractComplexVector toAbstractComplexVector(RComplexVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractComplexVector toAbstractComplexVector(RComplex vector) { + return RDataFactory.createComplexVectorFromScalar(vector); + } + + @ImplicitCast + public static RAbstractLogicalVector toAbstractLogicalVector(byte vector) { + return RDataFactory.createLogicalVectorFromScalar(vector); + } + + @ImplicitCast + public static RAbstractLogicalVector toAbstractLogicalVector(RLogicalVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractRawVector toAbstractRawVector(RRaw vector) { + return RDataFactory.createRawVectorFromScalar(vector); + } + + @ImplicitCast + public static RAbstractRawVector toAbstractRawVector(RRawVector vector) { + return vector; + } + + @ImplicitCast + public static RAbstractStringVector toAbstractStringVector(String vector) { + return RDataFactory.createStringVectorFromScalar(vector); + } + + @ImplicitCast + public static RAbstractStringVector toAbstractStringVector(RStringVector vector) { + return vector; + } + + @ImplicitCast + public static RMissing toRMissing(@SuppressWarnings("unused") REmpty empty) { + return RMissing.instance; + } +}