Skip to content
Snippets Groups Projects
Commit 448a41d8 authored by Florian Angerer's avatar Florian Angerer
Browse files

Restructured special subset/subscript implementation.

parent ffc57072
No related branches found
No related tags found
No related merge requests found
/*
* Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.nodes.builtin.base.infix;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertValue;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
public class ProfiledSpecialsUtils {
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index", type = ConvertIndex.class)
protected abstract static class ProfiledSubscriptSpecialBase extends RNode {
protected static final int CACHE_LIMIT = 3;
protected final boolean inReplacement;
@Child protected SubscriptSpecialBase defaultAccessNode;
protected ProfiledSubscriptSpecialBase(boolean inReplacement) {
this.inReplacement = inReplacement;
}
protected SubscriptSpecialBase createAccessNode() {
throw RInternalError.shouldNotReachHere();
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, RAbstractVector vector, Object index,
@Cached(value = "vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") SubscriptSpecialBase accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index);
}
@Specialization(replaces = "access")
public Object accessGeneric(VirtualFrame frame, Object vector, Object index) {
if (defaultAccessNode == null) {
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index);
}
}
public abstract static class ProfiledSubscriptSpecial extends ProfiledSubscriptSpecialBase {
protected ProfiledSubscriptSpecial(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecialBase createAccessNode() {
return SubscriptSpecialNodeGen.create(inReplacement);
}
}
public abstract static class ProfiledSubsetSpecial extends ProfiledSubscriptSpecialBase {
protected ProfiledSubsetSpecial(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecialBase createAccessNode() {
return SubsetSpecialNodeGen.create(inReplacement);
}
}
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index1", type = ConvertIndex.class)
@NodeChild(value = "index2", type = ConvertIndex.class)
public abstract static class ProfiledSubscriptSpecial2Base extends RNode {
protected static final int CACHE_LIMIT = 3;
protected final boolean inReplacement;
@Child protected SubscriptSpecial2Base defaultAccessNode;
protected ProfiledSubscriptSpecial2Base(boolean inReplacement) {
this.inReplacement = inReplacement;
}
protected SubscriptSpecial2Base createAccessNode() {
throw RInternalError.shouldNotReachHere();
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, RAbstractVector vector, Object index1, Object index2,
@Cached("vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") SubscriptSpecial2Base accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index1, index2);
}
@Specialization(replaces = "access")
public Object accessGeneric(VirtualFrame frame, Object vector, Object index1, Object index2) {
if (defaultAccessNode == null) {
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index1, index2);
}
}
public abstract static class ProfiledSubscriptSpecial2 extends ProfiledSubscriptSpecial2Base {
protected ProfiledSubscriptSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecial2Base createAccessNode() {
return SubscriptSpecial2NodeGen.create(inReplacement);
}
}
public abstract static class ProfiledSubsetSpecial2 extends ProfiledSubscriptSpecial2Base {
protected ProfiledSubsetSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecial2Base createAccessNode() {
return SubsetSpecial2NodeGen.create(inReplacement);
}
}
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index", type = ConvertIndex.class)
@NodeChild(value = "value", type = ConvertValue.class)
public abstract static class ProfiledUpdateSubscriptSpecialBase extends RNode {
protected static final int CACHE_LIMIT = 3;
protected final boolean inReplacement;
public abstract Object execute(VirtualFrame frame, Object vector, Object index, Object value);
@Child protected UpdateSubscriptSpecial defaultAccessNode;
protected ProfiledUpdateSubscriptSpecialBase(boolean inReplacement) {
this.inReplacement = inReplacement;
}
protected UpdateSubscriptSpecial createAccessNode() {
return UpdateSubscriptSpecialNodeGen.create(inReplacement);
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, Object vector, Object index, Object value,
@Cached("vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") UpdateSubscriptSpecial accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index, value);
}
@Specialization(replaces = "access")
public Object accessGeneric(VirtualFrame frame, Object vector, Object index, Object value) {
if (defaultAccessNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index, value);
}
}
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index1", type = ConvertIndex.class)
@NodeChild(value = "index2", type = ConvertIndex.class)
@NodeChild(value = "value", type = ConvertValue.class)
public abstract static class ProfiledUpdateSubscriptSpecial2 extends RNode {
protected static final int CACHE_LIMIT = 3;
protected final boolean inReplacement;
public abstract Object execute(VirtualFrame frame, Object vector, Object index1, Object index2, Object value);
@Child protected UpdateSubscriptSpecial2 defaultAccessNode;
protected ProfiledUpdateSubscriptSpecial2(boolean inReplacement) {
this.inReplacement = inReplacement;
}
protected UpdateSubscriptSpecial2 createAccessNode() {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement);
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, Object vector, Object index1, Object index2, Object value,
@Cached("vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") UpdateSubscriptSpecial2 accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index1, index2, value);
}
@Specialization(replaces = "access")
public Object accessGeneric(VirtualFrame frame, Object vector, Object index1, Object index2, Object value) {
if (defaultAccessNode == null) {
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index1, index2, value);
}
}
}
......@@ -25,27 +25,23 @@ package com.oracle.truffle.r.nodes.builtin.base.infix;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeCost;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertIndexNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ConvertValueNodeGen;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
......@@ -67,155 +63,12 @@ class SpecialsUtils {
return false;
}
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index", type = ConvertIndex.class)
protected abstract static class ProfiledSubscriptSpecialBase extends SubscriptSpecialCommon {
protected static final int CACHE_LIMIT = 3;
@Child protected SubscriptSpecialBase defaultAccessNode;
protected ProfiledSubscriptSpecialBase(boolean inReplacement) {
super(inReplacement);
}
protected SubscriptSpecialBase createAccessNode() {
throw RInternalError.shouldNotReachHere();
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, RAbstractVector vector, int index, @Cached(value = "vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") SubscriptSpecialBase accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index);
}
@Fallback
public Object accessGeneric(VirtualFrame frame, Object vector, Object index) {
if (defaultAccessNode == null) {
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index);
}
}
public abstract static class ProfiledSubscriptSpecial extends ProfiledSubscriptSpecialBase {
protected ProfiledSubscriptSpecial(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecialBase createAccessNode() {
return SubscriptSpecialNodeGen.create(inReplacement);
}
}
public abstract static class ProfiledSubsetSpecial extends ProfiledSubscriptSpecialBase {
@Child protected SubsetSpecial accessNode;
protected ProfiledSubsetSpecial(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecialBase createAccessNode() {
return SubsetSpecialNodeGen.create(inReplacement);
}
}
@NodeChild(value = "vector", type = RNode.class)
@NodeChild(value = "index1", type = ConvertIndex.class)
@NodeChild(value = "index2", type = ConvertIndex.class)
protected abstract static class ProfiledSubscriptSpecial2Base extends SubscriptSpecialCommon {
protected static final int CACHE_LIMIT = 3;
@Child protected SubscriptSpecial2Base defaultAccessNode;
protected ProfiledSubscriptSpecial2Base(boolean inReplacement) {
super(inReplacement);
}
protected SubscriptSpecial2Base createAccessNode() {
throw RInternalError.shouldNotReachHere();
}
@Specialization(limit = "CACHE_LIMIT", guards = "vector.getClass() == clazz")
public Object access(VirtualFrame frame, RAbstractVector vector, int index1, int index2, @Cached("vector.getClass()") Class<?> clazz,
@Cached("createAccessNode()") SubscriptSpecial2Base accessNodeCached) {
return accessNodeCached.execute(frame, clazz.cast(vector), index1, index2);
}
@Fallback
public Object accessGeneric(VirtualFrame frame, Object vector, Object index1, Object index2) {
if (defaultAccessNode == null) {
defaultAccessNode = insert(createAccessNode());
}
return defaultAccessNode.execute(frame, vector, index1, index2);
}
}
public abstract static class ProfiledSubscriptSpecial2 extends ProfiledSubscriptSpecial2Base {
@Child protected SubscriptSpecial2 accessNode;
protected ProfiledSubscriptSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecial2Base createAccessNode() {
return SubscriptSpecial2NodeGen.create(inReplacement);
}
}
public abstract static class ProfiledSubsetSpecial2 extends ProfiledSubscriptSpecial2Base {
@Child protected SubsetSpecial2 accessNode;
protected ProfiledSubsetSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Override
protected SubscriptSpecial2Base createAccessNode() {
return SubsetSpecial2NodeGen.create(inReplacement);
}
}
/**
* Common code shared between specials doing subset/subscript related operation.
*/
abstract static class SubscriptSpecialCommon1 extends Node {
protected final boolean inReplacement;
protected SubscriptSpecialCommon1(boolean inReplacement) {
this.inReplacement = inReplacement;
}
abstract static class SubscriptSpecialCommon extends Node {
/**
* Checks whether the given (1-based) index is valid for the given vector.
*/
protected static boolean isValidIndex(RAbstractVector vector, int index) {
return index >= 1 && index <= vector.getLength();
}
/**
* Checks if the value is single element that can be put into a list or vector as is,
* because in the case of vectors on the LSH of update we take each element and put it into
* the RHS of the update function.
*/
protected static boolean isSingleElement(Object value) {
return value instanceof Integer || value instanceof Double || value instanceof Byte || value instanceof String;
}
}
/**
* Common code shared between specials doing subset/subscript related operation.
*/
abstract static class SubscriptSpecialCommon extends RNode {
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
protected final boolean inReplacement;
......@@ -223,6 +76,10 @@ class SpecialsUtils {
this.inReplacement = inReplacement;
}
protected boolean simpleVector(RAbstractVector vector) {
return classHierarchy.execute(vector) == null;
}
/**
* Checks whether the given (1-based) index is valid for the given vector.
*/
......@@ -261,27 +118,6 @@ class SpecialsUtils {
}
}
abstract static class SubscriptSpecial2Common1 extends SubscriptSpecialCommon1 {
protected SubscriptSpecial2Common1(boolean inReplacement) {
super(inReplacement);
}
@Child private GetDimAttributeNode getDimensions = GetDimAttributeNode.create();
protected int matrixIndex(RAbstractVector vector, int index1, int index2) {
return index1 - 1 + ((index2 - 1) * getDimensions.getDimensions(vector)[0]);
}
/**
* Checks whether the given (1-based) indexes are valid for the given matrix.
*/
protected boolean isValidIndex(RAbstractVector vector, int index1, int index2) {
int[] dimensions = getDimensions.getDimensions(vector);
return dimensions != null && dimensions.length == 2 && index1 >= 1 && index1 <= dimensions[0] && index2 >= 1 && index2 <= dimensions[1];
}
}
/**
* Common code shared between specials accessing/updating fields.
*/
......@@ -318,27 +154,6 @@ class SpecialsUtils {
}
}
@NodeInfo(cost = NodeCost.NONE)
public static final class ProfiledValue extends RBaseNode {
private final ValueProfile profile = ValueProfile.createClassProfile();
@Child private RNode delegate;
protected ProfiledValue(RNode delegate) {
this.delegate = delegate;
}
public Object execute(VirtualFrame frame) {
return profile.profile(delegate.execute(frame));
}
@Override
protected RSyntaxNode getRSyntaxNode() {
return delegate.asRSyntaxNode();
}
}
@NodeInfo(cost = NodeCost.NONE)
@NodeChild(value = "delegate", type = RNode.class)
public abstract static class ConvertIndex extends RNode {
......@@ -417,10 +232,6 @@ class SpecialsUtils {
}
}
public static ProfiledValue profile(RNode value) {
return new ProfiledValue(value);
}
public static ConvertIndex convertIndex(RNode value) {
return ConvertIndexNodeGen.create(value);
}
......
......@@ -36,13 +36,11 @@ import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ExtractListElement;
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledSubscriptSpecial2NodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledSubscriptSpecialNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common1;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon1;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ProfiledSubscriptSpecial2NodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ProfiledSubscriptSpecialNodeGen;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime;
......@@ -64,20 +62,14 @@ import com.oracle.truffle.r.runtime.nodes.RNode;
/**
* Subscript code for vectors minus list is the same as subset code, this class allows sharing it.
*/
abstract class SubscriptSpecialBase extends SubscriptSpecialCommon1 {
abstract class SubscriptSpecialBase extends SubscriptSpecialCommon {
protected SubscriptSpecialBase(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
protected abstract Object execute(VirtualFrame frame, Object vec, Object index);
protected boolean simpleVector(RAbstractVector vector) {
return classHierarchy.execute(vector) == null;
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index)"})
protected int access(RAbstractIntVector vector, int index) {
return vector.getDataAt(index - 1);
......@@ -103,20 +95,14 @@ abstract class SubscriptSpecialBase extends SubscriptSpecialCommon1 {
/**
* Subscript code for matrices minus list is the same as subset code, this class allows sharing it.
*/
abstract class SubscriptSpecial2Base extends SubscriptSpecial2Common1 {
abstract class SubscriptSpecial2Base extends SubscriptSpecial2Common {
protected SubscriptSpecial2Base(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
public abstract Object execute(VirtualFrame frame, Object vector, Object index1, Object index2);
protected boolean simpleVector(RAbstractVector vector) {
return classHierarchy.execute(vector) == null;
}
@Specialization(guards = {"simpleVector(vector)", "isValidIndex(vector, index1, index2)"})
protected int access(RAbstractIntVector vector, int index1, int index2) {
return vector.getDataAt(matrixIndex(vector, index1, index2));
......
......@@ -35,9 +35,9 @@ import com.oracle.truffle.r.nodes.access.vector.ExtractListElement;
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledSubsetSpecial2NodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledSubsetSpecialNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ProfiledSubsetSpecial2NodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtilsFactory.ProfiledSubsetSpecialNodeGen;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
......
......@@ -24,7 +24,6 @@ package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertValue;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -32,20 +31,18 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledUpdateSubscriptSpecial2NodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.ProfiledSpecialsUtilsFactory.ProfiledUpdateSubscriptSpecialBaseNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertValue;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecial2Common;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.SubscriptSpecialCommon;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
......@@ -59,45 +56,38 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@NodeChild(value = "vector", type = ProfiledValue.class)
@NodeChild(value = "index", type = ConvertIndex.class)
@NodeChild(value = "value", type = ConvertValue.class)
abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon {
protected UpdateSubscriptSpecial(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
private final NACheck naCheck = NACheck.create();
protected boolean simple(Object vector) {
return classHierarchy.execute(vector) == null;
}
protected abstract Object execute(VirtualFrame frame, Object vec, Object index, Object value);
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RIntVector set(RIntVector vector, int index, int value) {
return vector.updateDataAt(index - 1, value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RDoubleVector set(RDoubleVector vector, int index, double value) {
return vector.updateDataAt(index - 1, value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RStringVector set(RStringVector vector, int index, String value) {
return vector.updateDataAt(index - 1, value, naCheck);
}
@Specialization(guards = {"simple(list)", "!list.isShared()", "isValidIndex(list, index)", "isSingleElement(value)"})
@Specialization(guards = {"simpleVector(list)", "!list.isShared()", "isValidIndex(list, index)", "isSingleElement(value)"})
protected static Object set(RList list, int index, Object value) {
list.setDataAt(list.getInternalStore(), index - 1, value);
return list;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index)"})
protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index, int value) {
return vector.updateDataAt(index - 1, value, naCheck);
}
......@@ -107,48 +97,44 @@ abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon {
protected static Object setFallback(Object vector, Object index, Object value) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
public static RNode create(boolean inReplacement, RNode vector, ConvertIndex index, ConvertValue value) {
return ProfiledUpdateSubscriptSpecialBaseNodeGen.create(inReplacement, vector, index, value);
}
}
@NodeChild(value = "vector", type = ProfiledValue.class)
@NodeChild(value = "index1", type = ConvertIndex.class)
@NodeChild(value = "index2", type = ConvertIndex.class)
@NodeChild(value = "value", type = ConvertValue.class)
abstract class UpdateSubscriptSpecial2 extends SubscriptSpecial2Common {
protected UpdateSubscriptSpecial2(boolean inReplacement) {
super(inReplacement);
}
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
private final NACheck naCheck = NACheck.create();
protected boolean simple(Object vector) {
return classHierarchy.execute(vector) == null;
}
protected abstract Object execute(VirtualFrame frame, Object vec, Object index1, Object index2, Object value);
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RIntVector set(RIntVector vector, int index1, int index2, int value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RDoubleVector set(RDoubleVector vector, int index1, int index2, double value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RStringVector set(RStringVector vector, int index1, int index2, String value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
@Specialization(guards = {"simple(list)", "!list.isShared()", "isValidIndex(list, index1, index2)", "isSingleElement(value)"})
@Specialization(guards = {"simpleVector(list)", "!list.isShared()", "isValidIndex(list, index1, index2)", "isSingleElement(value)"})
protected Object set(RList list, int index1, int index2, Object value) {
list.setDataAt(list.getInternalStore(), matrixIndex(list, index1, index2), value);
return list;
}
@Specialization(guards = {"simple(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
@Specialization(guards = {"simpleVector(vector)", "!vector.isShared()", "isValidIndex(vector, index1, index2)"})
protected RDoubleVector setDoubleIntIndexIntValue(RDoubleVector vector, int index1, int index2, int value) {
return vector.updateDataAt(matrixIndex(vector, index1, index2), value, naCheck);
}
......@@ -158,6 +144,10 @@ abstract class UpdateSubscriptSpecial2 extends SubscriptSpecial2Common {
protected static Object setFallback(Object vector, Object index1, Object index2, Object value) {
throw RSpecialFactory.throwFullCallNeeded(value);
}
public static RNode create(boolean inReplacement, RNode vector, ConvertIndex index1, ConvertIndex index2, ConvertValue value) {
return ProfiledUpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index1, index2, value);
}
}
@RBuiltin(name = "[[<-", kind = PRIMITIVE, parameterNames = {"", "..."}, dispatch = INTERNAL_GENERIC, behavior = PURE)
......@@ -173,12 +163,11 @@ public abstract class UpdateSubscript extends RBuiltinNode.Arg2 {
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertIndex(args[1]);
if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, convertValue(args[2]));
return UpdateSubscriptSpecial.create(inReplacement, args[0], index, convertValue(args[2]));
} else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertIndex(args[2]), convertValue(args[3]));
return UpdateSubscriptSpecial2.create(inReplacement, args[0], index, convertIndex(args[2]), convertValue(args[3]));
}
}
return null;
......
......@@ -24,7 +24,6 @@ package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertValue;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -38,7 +37,6 @@ import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
import com.oracle.truffle.r.nodes.access.vector.ReplaceVectorNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ConvertIndex;
import com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.ProfiledValue;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
......@@ -58,12 +56,11 @@ public abstract class UpdateSubset extends RBuiltinNode.Arg1 {
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertIndex(args[1]);
if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, convertValue(args[2]));
return UpdateSubscriptSpecial.create(inReplacement, args[0], index, convertValue(args[2]));
} else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertIndex(args[2]), convertValue(args[3]));
return UpdateSubscriptSpecial2.create(inReplacement, args[0], index, convertIndex(args[2]), convertValue(args[3]));
}
}
return null;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment