Skip to content
Snippets Groups Projects
Commit 52869ac6 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

special calls for matrix operations, thread safety and new DSL layout in...

special calls for matrix operations, thread safety and new DSL layout in field/subset/subscript functions
parent f46a244c
No related branches found
No related tags found
No related merge requests found
Showing
with 403 additions and 138 deletions
......@@ -39,7 +39,6 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ExtractListElement;
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
......@@ -49,6 +48,7 @@ import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.builtins.RSpecialFactory;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractListVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
......@@ -59,19 +59,24 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
@Child private ExtractListElement extractListElement = ExtractListElement.create();
@Specialization(guards = {"isSimpleList(list)", "isCached(list, field)", "list.getNames() != null"})
public Object doList(RList list, String field,
@Cached("getIndex(list.getNames(), field)") int index) {
@Specialization(limit = "2", guards = {"isSimpleList(list)", "list.getNames() == cachedNames", "field == cachedField"})
public Object doList(RList list, @SuppressWarnings("unused") String field,
@SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames,
@SuppressWarnings("unused") @Cached("field") String cachedField,
@Cached("getIndex(cachedNames, field)") int index) {
if (index == -1) {
throw RSpecialFactory.throwFullCallNeeded();
}
updateCache(list, field);
return extractListElement.execute(list, index);
}
@Specialization(contains = "doList", guards = {"isSimpleList(list)", "list.getNames() != null"})
public Object doListDynamic(RList list, String field, @Cached("create()") GetNamesAttributeNode getNamesNode) {
return doList(list, field, getIndex(getNamesNode.getNames(list), field));
public Object doListDynamic(RList list, String field) {
int index = getIndex(getNamesNode.getNames(list), field);
if (index == -1) {
throw RSpecialFactory.throwFullCallNeeded();
}
return extractListElement.execute(list, index);
}
@Fallback
......@@ -82,6 +87,7 @@ abstract class AccessFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
}
@RBuiltin(name = "$", kind = PRIMITIVE, parameterNames = {"", ""}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract class AccessField extends RBuiltinNode {
@Child private ExtractVectorNode extract = ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, true);
......
......@@ -22,17 +22,28 @@
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.NodeCost;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertIndexNodeGen;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
/**
* Helper methods for implementing special calls.
......@@ -43,72 +54,86 @@ class SpecialsUtils {
private static final String valueArgName = "value".intern();
public static boolean isCorrectUpdateSignature(ArgumentsSignature signature) {
return signature.getLength() == 3 && signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == valueArgName;
if (signature.getLength() == 3) {
return signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == valueArgName;
} else if (signature.getLength() == 4) {
return signature.getName(0) == null && signature.getName(1) == null && signature.getName(2) == null && signature.getName(3) == valueArgName;
}
return false;
}
/**
* Common code shared between specials doing subset/subscript related operation.
*/
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract static class SubscriptSpecialCommon extends RNode {
protected final ValueProfile vectorClassProfile = ValueProfile.createClassProfile();
protected final boolean inReplacement;
protected boolean isValidIndex(RAbstractVector vector, int index) {
vector = vectorClassProfile.profile(vector);
return index >= 1 && index <= vector.getLength();
protected SubscriptSpecialCommon(boolean inReplacement) {
this.inReplacement = inReplacement;
}
protected boolean isValidDoubleIndex(RAbstractVector vector, double index) {
return isValidIndex(vector, toIndex(index));
/**
* Checks whether the given (1-based) index is valid for the given vector.
*/
protected static boolean isValidIndex(RAbstractVector vector, int index) {
return index >= 1 && index <= vector.getLength();
}
/**
* Note: conversion from double to an index differs in subscript and subset.
* Checks if the value is single element that can be put into a list or vector as is,
* because in the case of vectors on the LSH of update we take each element and put it into
* the RHS of the update function.
*/
protected int toIndex(double index) {
if (index == 0) {
return 0;
}
int i = (int) index;
return i == 0 ? 1 : i;
protected static boolean isSingleElement(Object value) {
return value instanceof Integer || value instanceof Double || value instanceof Byte || value instanceof String;
}
}
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract static class SubscriptSpecial2Common extends SubscriptSpecialCommon {
protected SubscriptSpecial2Common(boolean inReplacement) {
super(inReplacement);
}
@Child private GetDimAttributeNode getDimensions = GetDimAttributeNode.create();
protected int matrixIndex(RAbstractVector vector, int index1, int index2) {
return index1 - 1 + ((index2 - 1) * getDimensions.getDimensions(vector)[0]);
}
protected static int toIndexSubset(double index) {
return index == 0 ? 0 : (int) index;
/**
* Checks whether the given (1-based) indexes are valid for the given matrix.
*/
protected static boolean isValidIndex(RAbstractVector vector, int index1, int index2) {
int[] dimensions = vector.getDimensions();
return dimensions != null && dimensions.length == 2 && index1 >= 1 && index1 <= dimensions[0] && index2 >= 1 && index2 <= dimensions[1];
}
}
/**
* Common code shared between specials accessing/updating fields.
*/
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract static class ListFieldSpecialBase extends RNode {
@CompilationFinal private String cachedField;
@CompilationFinal private RStringVector cachedNames;
@Child private ClassHierarchyNode hierarchyNode = ClassHierarchyNode.create();
@Child protected GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
protected final void updateCache(RList list, String field) {
if (cachedField == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
cachedField = field;
cachedNames = getNamesNode.getNames(list);
}
}
protected final boolean isSimpleList(RList list) {
return hierarchyNode.execute(list) == null;
}
protected final boolean isCached(RList list, String field) {
return cachedField == null || (cachedField == field && getNamesNode.getNames(list) == cachedNames);
}
protected static int getIndex(RStringVector names, String field) {
int fieldHash = field.hashCode();
for (int i = 0; i < names.getLength(); i++) {
String current = names.getDataAt(i);
if (current == field || hashCodeEquals(current, fieldHash) && contentsEquals(current, field)) {
return i;
if (names != null) {
int fieldHash = field.hashCode();
for (int i = 0; i < names.getLength(); i++) {
String current = names.getDataAt(i);
if (current == field || hashCodeEquals(current, fieldHash) && contentsEquals(current, field)) {
return i;
}
}
}
return -1;
......@@ -124,4 +149,79 @@ class SpecialsUtils {
return current.hashCode() == fieldHash;
}
}
@NodeInfo(cost = NodeCost.NONE)
public static final class ProfiledValue extends RBaseNode {
private final ValueProfile profile = ValueProfile.createClassProfile();
@Child private RNode delegate;
protected ProfiledValue(RNode delegate) {
this.delegate = delegate;
}
public Object execute(VirtualFrame frame) {
return profile.profile(delegate.execute(frame));
}
@Override
protected RSyntaxNode getRSyntaxNode() {
return delegate.asRSyntaxNode();
}
}
@NodeInfo(cost = NodeCost.NONE)
@NodeChildren({@NodeChild(value = "delegate", type = RNode.class)})
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract static class ConvertIndex extends RNode {
private final boolean isSubset;
private final ConditionProfile zeroProfile;
ConvertIndex(boolean isSubset) {
this.isSubset = isSubset;
this.zeroProfile = isSubset ? null : ConditionProfile.createBinaryProfile();
}
protected abstract RNode getDelegate();
@Specialization
protected static int convertInteger(int value) {
return value;
}
@Specialization
protected int convertDouble(double value) {
// Conversion from double to an index differs in subscript and subset.
int intValue = (int) value;
if (isSubset) {
return intValue;
} else {
return zeroProfile.profile(intValue == 0) ? (value == 0 ? 0 : 1) : intValue;
}
}
@Specialization(contains = {"convertInteger", "convertDouble"})
protected Object convert(Object value) {
return value;
}
@Override
protected RSyntaxNode getRSyntaxNode() {
return getDelegate().asRSyntaxNode();
}
}
public static ProfiledValue profile(RNode value) {
return new ProfiledValue(value);
}
public static ConvertIndex convertSubscript(RNode value) {
return ConvertIndexNodeGen.create(false, value);
}
public static ConvertIndex convertSubset(RNode value) {
return ConvertIndexNodeGen.create(true, value);
}
}
......@@ -22,6 +22,8 @@
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -29,14 +31,17 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ExtractListElement;
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
......@@ -50,6 +55,7 @@ import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RLogical;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RTypesFlatLayout;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
......@@ -60,10 +66,13 @@ import com.oracle.truffle.r.runtime.nodes.RNode;
/**
* Subscript code for vectors minus list is the same as subset code, this class allows sharing it.
*/
@NodeChild(value = "arguments", type = RNode[].class)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index", type = ConvertIndex.class)})
abstract class SubscriptSpecialBase extends SubscriptSpecialCommon {
protected SubscriptSpecialBase(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
protected boolean simpleVector(RAbstractVector vector) {
......@@ -72,58 +81,108 @@ abstract class SubscriptSpecialBase extends SubscriptSpecialCommon {
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
protected int access(RAbstractIntVector vector, int index) {
return vectorClassProfile.profile(vector).getDataAt(index - 1);
return vector.getDataAt(index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
protected double access(RAbstractDoubleVector vector, int index) {
return vectorClassProfile.profile(vector).getDataAt(index - 1);
return vector.getDataAt(index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
protected String access(RAbstractStringVector vector, int index) {
return vectorClassProfile.profile(vector).getDataAt(index - 1);
return vector.getDataAt(index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"})
protected int access(RAbstractIntVector vector, double index) {
return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1);
@SuppressWarnings("unused")
@Fallback
protected static Object access(Object vector, Object index) {
throw RSpecialFactory.throwFullCallNeeded();
}
}
/**
* Subscript code for matrices minus list is the same as subset code, this class allows sharing it.
*/
@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index1", type = ConvertIndex.class), @NodeChild(value = "index2", type = ConvertIndex.class)})
abstract class SubscriptSpecial2Base extends SubscriptSpecial2Common {
@Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"})
protected double access(RAbstractDoubleVector vector, double index) {
return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1);
protected SubscriptSpecial2Base(boolean inReplacement) {
super(inReplacement);
}
@Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"})
protected String access(RAbstractStringVector vector, double index) {
return vectorClassProfile.profile(vector).getDataAt(toIndex(index) - 1);
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
protected abstract ProfiledValue getVector();
protected abstract ConvertIndex getIndex1();
protected abstract ConvertIndex getIndex2();
protected boolean simpleVector(RAbstractVector vector) {
return classHierarchy.execute(vector) == null;
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"})
protected int access(RAbstractIntVector vector, int index1, int index2) {
return vector.getDataAt(matrixIndex(vector, index1, index2));
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"})
protected double access(RAbstractDoubleVector vector, int index1, int index2) {
return vector.getDataAt(matrixIndex(vector, index1, index2));
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"})
protected String access(RAbstractStringVector vector, int index1, int index2) {
return vector.getDataAt(matrixIndex(vector, index1, index2));
}
@SuppressWarnings("unused")
@Fallback
protected static Object access(Object vector, Object index) {
protected static Object access(Object vector, Object index1, Object index2) {
throw RSpecialFactory.throwFullCallNeeded();
}
}
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract class SubscriptSpecial extends SubscriptSpecialBase {
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
protected SubscriptSpecial(boolean inReplacement) {
super(inReplacement);
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)", "!inReplacement"})
protected static Object access(RList vector, int index,
@Cached("create()") ExtractListElement extract) {
return extract.execute(vector, index - 1);
}
@Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"})
protected Object access(RList vector, double index,
protected static ExtractVectorNode createAccess() {
return ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false);
}
@Specialization(guards = {"simpleVector(vector)", "!inReplacement"})
protected static Object access(VirtualFrame frame, RAbstractVector vector, Object index,
@Cached("createAccess()") ExtractVectorNode extract) {
return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE);
}
}
abstract class SubscriptSpecial2 extends SubscriptSpecial2Base {
protected SubscriptSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)", "!inReplacement"})
protected Object access(RList vector, int index1, int index2,
@Cached("create()") ExtractListElement extract) {
return extract.execute(vector, toIndex(index) - 1);
return extract.execute(vector, matrixIndex(vector, index1, index2));
}
}
@RBuiltin(name = "[[", kind = PRIMITIVE, parameterNames = {"x", "...", "exact", "drop"}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(RTypesFlatLayout.class)
public abstract class Subscript extends RBuiltinNode {
@RBuiltin(name = ".subset2", kind = PRIMITIVE, parameterNames = {"x", "...", "exact", "drop"}, behavior = PURE)
......@@ -131,8 +190,15 @@ public abstract class Subscript extends RBuiltinNode {
// same implementation as "[[", with different dispatch
}
public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) {
return signature.getNonNullCount() == 0 && arguments.length == 2 ? SubscriptSpecialNodeGen.create(arguments) : null;
public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) {
if (signature.getNonNullCount() == 0) {
if (arguments.length == 2) {
return SubscriptSpecialNodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1]));
} else if (arguments.length == 3) {
return SubscriptSpecial2NodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1]), convertSubscript(arguments[2]));
}
}
return null;
}
@Child private ExtractVectorNode extractNode = ExtractVectorNode.create(ElementAccessMode.SUBSCRIPT, false);
......
......@@ -22,6 +22,8 @@
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -37,7 +39,10 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames;
import com.oracle.truffle.r.runtime.data.RDataFactory;
......@@ -52,36 +57,63 @@ import com.oracle.truffle.r.runtime.nodes.RNode;
* Subset special only handles single element integer/double index. In the case of list, we need to
* create the actual list otherwise we just return the primitive type.
*/
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract class SubsetSpecial extends SubscriptSpecialBase {
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
@Override
protected boolean simpleVector(RAbstractVector vector) {
vector = vectorClassProfile.profile(vector);
return super.simpleVector(vector) && getNamesNode.getNames(vector) == null;
protected SubsetSpecial(boolean inReplacement) {
super(inReplacement);
}
@Override
protected int toIndex(double index) {
return toIndexSubset(index);
protected boolean simpleVector(RAbstractVector vector) {
return super.simpleVector(vector) && getNamesNode.getNames(vector) == null;
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)", "!inReplacement"})
protected static RList access(RList vector, int index,
@Cached("create()") ExtractListElement extract) {
return RDataFactory.createList(new Object[]{extract.execute(vector, index - 1)});
}
@Specialization(guards = {"simpleVector(vector)", "isValidDoubleIndex(vector, index)"})
protected static RList access(RList vector, double index,
protected static ExtractVectorNode createAccess() {
return ExtractVectorNode.create(ElementAccessMode.SUBSET, false);
}
@Specialization(guards = {"simpleVector(vector)", "!inReplacement"})
protected static Object access(VirtualFrame frame, RAbstractVector vector, Object index,
@Cached("createAccess()") ExtractVectorNode extract) {
return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE);
}
}
/**
* Subset special only handles single element integer/double index. In the case of list, we need to
* create the actual list otherwise we just return the primitive type.
*/
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract class SubsetSpecial2 extends SubscriptSpecial2Base {
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
protected SubsetSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Override
protected boolean simpleVector(RAbstractVector vector) {
return super.simpleVector(vector) && getNamesNode.getNames(vector) == null;
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)", "!inReplacement"})
protected RList access(RList vector, int index1, int index2,
@Cached("create()") ExtractListElement extract) {
return RDataFactory.createList(new Object[]{extract.execute(vector, toIndexSubset(index) - 1)});
return RDataFactory.createList(new Object[]{extract.execute(vector, matrixIndex(vector, index1, index2))});
}
}
@RBuiltin(name = "[", kind = PRIMITIVE, parameterNames = {"x", "...", "drop"}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract class Subset extends RBuiltinNode {
@RBuiltin(name = ".subset", kind = PRIMITIVE, parameterNames = {"", "...", "drop"}, behavior = PURE)
......@@ -89,14 +121,17 @@ public abstract class Subset extends RBuiltinNode {
// same implementation as "[", with different dispatch
}
public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) {
boolean correctSignature = signature.getNonNullCount() == 0 && arguments.length == 2;
if (!correctSignature) {
return null;
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (signature.getNonNullCount() == 0 && (args.length == 2 || args.length == 3)) {
ProfiledValue profiledVector = profile(args[0]);
ConvertIndex index = convertSubset(args[1]);
if (args.length == 2) {
return SubsetSpecialNodeGen.create(inReplacement, profiledVector, index);
} else {
return SubsetSpecial2NodeGen.create(inReplacement, profiledVector, index, convertSubset(args[2]));
}
}
// Subset adds support for lists returning newly created list, which cannot work when used
// in replacement, because we need the reference to the existing (materialized) list element
return inReplacement ? SubscriptSpecialBaseNodeGen.create(arguments) : SubsetSpecialNodeGen.create(arguments);
return null;
}
@Child private ExtractVectorNode extractNode = ExtractVectorNode.create(ElementAccessMode.SUBSET, false);
......
......@@ -32,8 +32,10 @@ import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
......@@ -48,6 +50,7 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.builtins.RSpecialFactory;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractListVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
......@@ -65,13 +68,14 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
return value != RNull.instance && !(value instanceof RList);
}
@Specialization(guards = {"isSimpleList(list)", "!list.isShared()", "isCached(list, field)", "list.getNames() != null", "isNotRNullRList(value)"})
public RList doList(RList list, String field, Object value,
@Cached("getIndex(list.getNames(), field)") int index) {
@Specialization(limit = "2", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() == cachedNames", "field == cachedField", "isNotRNullRList(value)"})
public Object doList(RList list, @SuppressWarnings("unused") String field, Object value,
@SuppressWarnings("unused") @Cached("list.getNames()") RStringVector cachedNames,
@SuppressWarnings("unused") @Cached("field") String cachedField,
@Cached("getIndex(cachedNames, field)") int index) {
if (index == -1) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
updateCache(list, field);
Object sharedValue = value;
// share only when necessary:
if (list.getDataAt(index) != value) {
......@@ -83,7 +87,17 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
@Specialization(contains = "doList", guards = {"isSimpleList(list)", "!list.isShared()", "list.getNames() != null", "isNotRNullRList(value)"})
public RList doListDynamic(RList list, String field, Object value) {
return doList(list, field, value, getIndex(getNamesNode.getNames(list), field));
int index = getIndex(getNamesNode.getNames(list), field);
if (index == -1) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
Object sharedValue = value;
// share only when necessary:
if (list.getDataAt(index) != value) {
sharedValue = getShareObjectNode().execute(value);
}
list.setElement(index, sharedValue);
return list;
}
@SuppressWarnings("unused")
......@@ -102,9 +116,10 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
}
@RBuiltin(name = "$<-", kind = PRIMITIVE, parameterNames = {"", "", "value"}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract class UpdateField extends RBuiltinNode {
@Child private ReplaceVectorNode extract = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, true);
@Child private ReplaceVectorNode update = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, true);
@Child private CastListNode castList;
private final ConditionProfile coerceList = ConditionProfile.createBinaryProfile();
......@@ -121,7 +136,7 @@ public abstract class UpdateField extends RBuiltinNode {
@Specialization
protected Object update(VirtualFrame frame, Object container, String field, Object value) {
Object list = coerceList.profile(container instanceof RAbstractListVector) ? container : coerceList(container);
return extract.apply(frame, list, new Object[]{field}, value);
return update.apply(frame, list, new Object[]{field}, value);
}
private Object coerceList(Object vector) {
......
......@@ -22,6 +22,8 @@
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -30,6 +32,7 @@ import java.util.Arrays;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
......@@ -38,6 +41,9 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
......@@ -54,25 +60,21 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@NodeChild(value = "arguments", type = RNode[].class)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index", type = ConvertIndex.class), @NodeChild(value = "value", type = RNode.class)})
abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon {
protected UpdateSubscriptSpecial(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
private final NACheck naCheck = NACheck.create();
protected boolean simple(Object vector) {
return classHierarchy.execute(vector) == null;
}
/**
* Checks if the value is single element that can be put into a list or vector as is, because in
* the case of vectors on the LSH of update we take each element and put it into the RHS of the
* update function.
*/
protected static boolean isSingleElement(Object value) {
return value instanceof Integer || value instanceof Double || value instanceof Byte || value instanceof String;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RIntVector set(RIntVector vector, int index, int value) {
return vector.updateDataAt(index - 1, value, naCheck);
......@@ -94,53 +96,86 @@ abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon {
return list;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"})
protected RIntVector setDoubleIndex(RIntVector vector, double index, int value) {
return vector.updateDataAt(toIndex(index) - 1, value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index, int value) {
return vector.updateDataAt(toIndex(index) - 1, value, naCheck);
return vector.updateDataAt(index - 1, value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"})
protected RDoubleVector setDoubleIndexIntValue(RDoubleVector vector, double index, int value) {
return vector.updateDataAt(toIndex(index) - 1, value, naCheck);
@SuppressWarnings("unused")
@Fallback
protected static Object setFallback(Object vector, Object index, Object value) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"})
protected RDoubleVector setDoubleIndex(RDoubleVector vector, double index, double value) {
return vector.updateDataAt(toIndex(index) - 1, value, naCheck);
@NodeChildren({@NodeChild(value = "vector", type = ProfiledValue.class), @NodeChild(value = "index1", type = ConvertIndex.class), @NodeChild(value = "index2", type = ConvertIndex.class),
@NodeChild(value = "value", type = RNode.class)})
abstract class UpdateSubscriptSpecial2 extends SubscriptSpecial2Common {
protected UpdateSubscriptSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidDoubleIndex(vector, index)"})
protected RStringVector setDoubleIndex(RStringVector vector, double index, String value) {
return vector.updateDataAt(toIndex(index) - 1, value, naCheck);
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
private final NACheck naCheck = NACheck.create();
protected boolean simple(Object vector) {
return classHierarchy.execute(vector) == null;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RIntVector set(RIntVector vector, int index1, int index2, int value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RDoubleVector set(RDoubleVector vector, int index1, int index2, double value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RStringVector set(RStringVector vector, int index1, int index2, String value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(list)", "!list.isShared()", "isValidDoubleIndex(list, index)", "isSingleElement(value)"})
protected Object setDoubleIndex(RList list, double index, Object value) {
list.setDataAt(list.getInternalStore(), toIndex(index) - 1, value);
@Specialization(guards = {"simple(list)", "!list.isShared()", "isValidIndex(list, index1, index2)", "isSingleElement(value)"})
protected Object set(RList list, int index1, int index2, Object value) {
list.setDataAt(list.getInternalStore(), matrixIndex(list, index1, index2), value);
return list;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index1, int index2, int value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@SuppressWarnings("unused")
@Fallback
protected static Object setFallback(Object vector, Object index, Object value) {
protected static Object setFallback(Object vector, Object index1, Object index2, Object value) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
}
@RBuiltin(name = "[[<-", kind = PRIMITIVE, parameterNames = {"", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract class UpdateSubscript extends RBuiltinNode {
@Child private ReplaceVectorNode replaceNode = ReplaceVectorNode.create(ElementAccessMode.SUBSCRIPT, false);
private final ConditionProfile argsLengthLargerThanOneProfile = ConditionProfile.createBinaryProfile();
public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) {
return SpecialsUtils.isCorrectUpdateSignature(signature) && arguments.length == 3 ? UpdateSubscriptSpecialNodeGen.create(arguments) : null;
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertSubscript(args[1]);
if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]);
} else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubscript(args[2]), args[3]);
}
}
return null;
}
@Specialization(guards = "!args.isEmpty()")
......
......@@ -22,13 +22,14 @@
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
......@@ -37,6 +38,8 @@ import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
......@@ -44,24 +47,24 @@ import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.nodes.RNode;
@NodeChild(value = "arguments", type = RNode[].class)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
abstract class UpdateSubsetSpecial extends UpdateSubscriptSpecial {
@Override
protected int toIndex(double index) {
return toIndexSubset(index);
}
}
@RBuiltin(name = "[<-", kind = PRIMITIVE, parameterNames = {"", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE)
@TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract class UpdateSubset extends RBuiltinNode {
@Child private ReplaceVectorNode replaceNode = ReplaceVectorNode.create(ElementAccessMode.SUBSET, false);
private final ConditionProfile argsLengthLargerThanOneProfile = ConditionProfile.createBinaryProfile();
public static RNode special(ArgumentsSignature signature, RNode[] arguments, @SuppressWarnings("unused") boolean inReplacement) {
return SpecialsUtils.isCorrectUpdateSignature(signature) && arguments.length == 3 ? UpdateSubsetSpecialNodeGen.create(arguments) : null;
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertSubset(args[1]);
if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]);
} else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubset(args[2]), args[3]);
}
}
return null;
}
@Specialization(guards = "!args.isEmpty()")
......
......@@ -57,7 +57,8 @@ public abstract class ExtractListElement extends Node {
}
@Specialization
protected Object doList(RListBase list, int index, @Cached("create()") UpdateShareableChildValueNode updateStateNode) {
protected Object doList(RListBase list, int index,
@Cached("create()") UpdateShareableChildValueNode updateStateNode) {
Object element = list.getDataAt(index);
return updateStateNode.updateState(list, element);
}
......
......@@ -69,6 +69,10 @@ abstract class ReplacementNode extends OperatorNode {
// Note: if specials are turned off in FastR, onlySpecials will never be true
boolean createSpecial = hasOnlySpecialCalls(calls);
if (createSpecial) {
/*
* This assumes that whenever there's a special call for the "extract", there's also a
* special call for "replace".
*/
if (isVoid) {
return new SpecialVoidReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
} else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment