Skip to content
Snippets Groups Projects
Commit d69ab22e authored by Lukas Stadler's avatar Lukas Stadler
Browse files

refactor field access, add common base class

parent 3baf5336
Branches
No related tags found
No related merge requests found
/*
* Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.nodes.access;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.*;
import com.oracle.truffle.api.utilities.*;
import com.oracle.truffle.r.nodes.*;
import com.oracle.truffle.r.runtime.data.*;
@NodeChild(value = "object", type = RNode.class)
@NodeField(name = "field", type = String.class)
public abstract class AccessFieldBaseNode extends RNode {
public abstract RNode getObject();
public abstract String getField();
protected final ConditionProfile hasNamesProfile = ConditionProfile.createBinaryProfile();
protected final BranchProfile inexactMatch = BranchProfile.create();
protected final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
@TruffleBoundary
public static int getElementIndexByName(RStringVector names, String name) {
for (int i = 0; i < names.getLength(); ++i) {
if (names.getDataAt(i).equals(name)) {
return i;
}
}
return -1;
}
}
......@@ -23,7 +23,6 @@
package com.oracle.truffle.r.nodes.access;
import com.oracle.truffle.api.dsl.*;
import com.oracle.truffle.api.utilities.*;
import com.oracle.truffle.r.nodes.*;
import com.oracle.truffle.r.runtime.*;
import com.oracle.truffle.r.runtime.RDeparse.State;
......@@ -34,52 +33,43 @@ import com.oracle.truffle.r.runtime.env.*;
/**
* Perform a field access. This node represents the {@code $} operator in R.
*/
@NodeChild(value = "object", type = RNode.class)
@NodeField(name = "field", type = String.class)
public abstract class AccessFieldNode extends RNode {
public abstract RNode getObject();
public abstract String getField();
private final BranchProfile inexactMatch = BranchProfile.create();
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
public abstract class AccessFieldNode extends AccessFieldBaseNode {
@Specialization
protected RNull access(@SuppressWarnings("unused") RNull object) {
return RNull.instance;
}
@Specialization(guards = "hasNames(object)")
@Specialization
protected Object accessField(RList object) {
int index = object.getElementIndexByName(attrProfiles, getField());
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, getField());
RStringVector names = object.getNames(attrProfiles);
if (hasNamesProfile.profile(names != null)) {
int index = getElementIndexByName(names, getField());
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, getField());
}
return index == -1 ? RNull.instance : object.getDataAt(index);
} else {
return RNull.instance;
}
return index == -1 ? RNull.instance : object.getDataAt(index);
}
@Specialization(guards = "!hasNames(object)")
protected Object accessFieldNoNames(@SuppressWarnings("unused") RList object) {
return RNull.instance;
}
// TODO: this should ultimately be a generic function
@Specialization(guards = "hasNames(object)")
@Specialization
protected Object accessField(RDataFrame object) {
int index = object.getElementIndexByName(attrProfiles, getField());
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, getField());
// TODO: add warning if index found (disabled by default using options)
RStringVector names = object.getNames(attrProfiles);
if (hasNamesProfile.profile(names != null)) {
int index = getElementIndexByName(names, getField());
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, getField());
// TODO: add warning if index found (disabled by default using options)
}
return index == -1 ? RNull.instance : object.getDataAtAsObject(index);
} else {
return RNull.instance;
}
return index == -1 ? RNull.instance : object.getDataAtAsObject(index);
}
@Specialization(guards = "!hasNames(object)")
protected Object accessFieldNoNames(@SuppressWarnings("unused") RDataFrame object) {
return RNull.instance;
}
@Specialization
......@@ -93,25 +83,15 @@ public abstract class AccessFieldNode extends RNode {
throw RError.error(RError.Message.DOLLAR_ATOMIC_VECTORS);
}
@Specialization(guards = "hasNames(object)")
@Specialization
protected Object accessFieldHasNames(RLanguage object) {
String field = getField();
RStringVector names = object.getNames(attrProfiles);
for (int i = 0; i < names.getLength(); i++) {
if (field.equals(names.getDataAt(i))) {
return RContext.getRASTHelper().getDataAtAsObject(object, i);
}
if (hasNamesProfile.profile(names != null)) {
int index = getElementIndexByName(names, getField());
return index == -1 ? RNull.instance : RContext.getRASTHelper().getDataAtAsObject(object, index);
} else {
return RNull.instance;
}
return RNull.instance;
}
@Specialization(guards = "!hasNames(object)")
protected Object accessField(@SuppressWarnings("unused") RLanguage object) {
return RNull.instance;
}
protected boolean hasNames(RAbstractContainer object) {
return object.getNames(attrProfiles) != null;
}
@Override
......@@ -136,5 +116,4 @@ public abstract class AccessFieldNode extends RNode {
}
return AccessFieldNodeGen.create(object, field);
}
}
......@@ -37,32 +37,28 @@ import com.oracle.truffle.r.runtime.data.model.*;
import com.oracle.truffle.r.runtime.env.*;
import com.oracle.truffle.r.runtime.env.REnvironment.*;
@NodeChildren({@NodeChild(value = "object", type = RNode.class), @NodeChild(value = "value", type = RNode.class)})
@NodeField(name = "field", type = String.class)
public abstract class UpdateFieldNode extends RNode {
public abstract RNode getObject();
@NodeChild(value = "value", type = RNode.class)
public abstract class UpdateFieldNode extends AccessFieldBaseNode {
public abstract RNode getValue();
public abstract String getField();
private final BranchProfile inexactMatch = BranchProfile.create();
private final BranchProfile noRemoval = BranchProfile.create();
private final ConditionProfile nullValueProfile = ConditionProfile.createBinaryProfile();
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
@Child private CastListNode castList;
@Specialization(guards = "!isNull(value)")
protected Object updateField(RList object, Object value) {
RStringVector names = object.getNames(attrProfiles);
int index = -1;
String field = getField();
int index = object.getElementIndexByName(attrProfiles, field);
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, field);
if (hasNamesProfile.profile(names != null)) {
index = getElementIndexByName(names, field);
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, field);
}
}
int newLength = object.getLength() + (index == -1 ? 1 : 0);
if (index == -1) {
index = newLength - 1;
......@@ -76,7 +72,6 @@ public abstract class UpdateFieldNode extends RNode {
if (object.getNames(attrProfiles) == null) {
Arrays.fill(resultNames, "");
} else {
RStringVector names = object.getNames(attrProfiles);
System.arraycopy(names.getDataWithoutCopying(), 0, resultNames, 0, names.getLength());
namesComplete = names.isComplete();
}
......@@ -93,11 +88,15 @@ public abstract class UpdateFieldNode extends RNode {
@Specialization(guards = "isNull(value)")
protected Object updateFieldNullValue(RList object, @SuppressWarnings("unused") Object value) {
RStringVector names = object.getNames(attrProfiles);
int index = -1;
String field = getField();
int index = object.getElementIndexByName(attrProfiles, field);
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, field);
if (hasNamesProfile.profile(names != null)) {
index = getElementIndexByName(names, field);
if (index == -1) {
inexactMatch.enter();
index = object.getElementIndexByNameInexact(attrProfiles, field);
}
}
if (index == -1) {
......@@ -120,7 +119,6 @@ public abstract class UpdateFieldNode extends RNode {
if (object.getNames(attrProfiles) == null) {
Arrays.fill(resultNames, "");
} else {
RStringVector names = object.getNames(attrProfiles);
ind = 0;
for (int i = 0; i < names.getLength(); i++) {
if (i != index) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment