Skip to content
Snippets Groups Projects
Commit 8ccb427e authored by Lukas Stadler's avatar Lukas Stadler
Browse files

Merge pull request #216 in G/fastr from...

Merge pull request #216 in G/fastr from ~LUKAS.STADLER_ORACLE.COM/fastr:bugfix/conditional_map_node to master

* commit '457b9cff':
  work on nchar
  additional helpers in CastBuilder.Predef, bugfix in ConditionalMapNode
parents 2971342d 457b9cff
Branches
No related tags found
No related merge requests found
......@@ -24,86 +24,52 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asIntegerVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asStringVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.helpers.InheritsCheckNode;
import com.oracle.truffle.r.nodes.unary.CastStringNode;
import com.oracle.truffle.r.nodes.unary.CastStringNodeGen;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RAttributeProfiles;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
// TODO interpret "type" and "allowNA" arguments
@RBuiltin(name = "nchar", kind = INTERNAL, parameterNames = {"x", "type", "allowNA", "keepNA"}, behavior = PURE)
public abstract class NChar extends RBuiltinNode {
@Child private CastStringNode convertString;
@Child private InheritsCheckNode factorInheritsCheck;
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
public abstract Object execute(Object value, Object type, Object allowNA, Object keepNA);
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("x").mapIf(stringValue(), asStringVector(), asIntegerVector());
casts.toLogical(2).toLogical(3);
}
private String coerceContent(Object content) {
if (convertString == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
convertString = insert(CastStringNodeGen.create(false, false, false, false));
}
return (String) convertString.executeString(content);
casts.arg("x").mapIf(Predef.integerValue(), asIntegerVector(), asStringVector(true, false, false));
casts.arg("type").asStringVector().findFirst();
casts.arg("allowNA").asLogicalVector().findFirst(RRuntime.LOGICAL_TRUE).map(toBoolean());
casts.arg("keepNA").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE).map(toBoolean());
}
@SuppressWarnings("unused")
@Specialization
protected RIntVector nchar(RNull value, String type, byte allowNA, byte keepNA) {
protected RIntVector nchar(RNull value, String type, boolean allowNA, boolean keepNA) {
return RDataFactory.createEmptyIntVector();
}
@SuppressWarnings("unused")
@Specialization
protected int nchar(int value, String type, byte allowNA, byte keepNA) {
return coerceContent(value).length();
}
@SuppressWarnings("unused")
@Specialization
protected int nchar(double value, String type, byte allowNA, byte keepNA) {
return coerceContent(value).length();
}
@SuppressWarnings("unused")
@Specialization
protected int nchar(byte value, String type, byte allowNA, byte keepNA) {
return coerceContent(value).length();
}
@SuppressWarnings("unused")
@Specialization
protected RIntVector ncharInt(RAbstractIntVector vector, String type, byte allowNA, byte keepNA) {
protected RIntVector ncharInt(RAbstractIntVector vector, String type, boolean allowNA, boolean keepNA,
@Cached("createCountingProfile()") LoopConditionProfile loopProfile,
@Cached("create()") RAttributeProfiles attrProfiles) {
int len = vector.getLength();
int[] result = new int[len];
for (int i = 0; i < len; i++) {
loopProfile.profileCounted(len);
for (int i = 0; loopProfile.inject(i < len); i++) {
int x = vector.getDataAt(i);
if (x == RRuntime.INT_NA) {
result[i] = 2;
......@@ -111,68 +77,20 @@ public abstract class NChar extends RBuiltinNode {
result[i] = (int) (Math.log10(x) + 1); // not the fastest one
}
}
return RDataFactory.createIntVector(result, false);
return RDataFactory.createIntVector(result, true, vector.getNames(attrProfiles));
}
@SuppressWarnings("unused")
@Specialization(guards = "vector.getLength() == 0")
protected RIntVector ncharL0(RAbstractStringVector vector, String type, byte allowNA, byte keepNA) {
return RDataFactory.createEmptyIntVector();
}
@SuppressWarnings("unused")
@Specialization(guards = "vector.getLength() == 1")
protected int ncharL1(RAbstractStringVector vector, String type, byte allowNA, byte keepNA) {
return vector.getDataAt(0).length();
}
@SuppressWarnings("unused")
@Specialization(guards = "vector.getLength() > 1")
protected RIntVector nchar(RAbstractStringVector vector, String type, byte allowNA, byte keepNA) {
@Specialization
protected RIntVector nchar(RAbstractStringVector vector, String type, boolean allowNA, boolean keepNA,
@Cached("createCountingProfile()") LoopConditionProfile loopProfile,
@Cached("create()") RAttributeProfiles attrProfiles) {
int len = vector.getLength();
int[] result = new int[len];
for (int i = 0; i < len; i++) {
loopProfile.profileCounted(len);
for (int i = 0; loopProfile.inject(i < len); i++) {
result[i] = vector.getDataAt(i).length();
}
return RDataFactory.createIntVector(result, vector.isComplete(), vector.getNames(attrProfiles));
}
protected static NChar createRecursive() {
return NCharNodeGen.create(null);
}
/*
* this builtin is sometimes used with only 3 arguments - keepNA defaults to FALSE.
*/
@Specialization
protected Object ncharNoKeepNA(Object obj, Object type, Object allowNA, @SuppressWarnings("unused") RMissing keepNA, //
@Cached("createRecursive()") NChar rec) {
return rec.execute(obj, type, allowNA, RRuntime.LOGICAL_FALSE);
}
@SuppressWarnings("unused")
@Fallback
protected RIntVector nchar(Object obj, Object type, Object allowNA, Object keepNA) {
if (factorInheritsCheck == null) {
CompilerDirectives.transferToInterpreter();
factorInheritsCheck = insert(new InheritsCheckNode(RRuntime.CLASS_FACTOR));
}
if (factorInheritsCheck.execute(obj)) {
throw RError.error(this, RError.Message.REQUIRES_CHAR_VECTOR, "nchar");
}
if (obj instanceof RAbstractVector) {
RAbstractVector vector = (RAbstractVector) obj;
int len = vector.getLength();
int[] result = new int[len];
for (int i = 0; i < len; i++) {
result[i] = coerceContent(vector.getDataAtAsObject(i)).length();
}
return RDataFactory.createIntVector(result, vector.isComplete(), vector.getNames(attrProfiles));
} else {
throw RError.error(this, RError.Message.CANNOT_COERCE, RRuntime.classToString(obj.getClass()), "character");
}
return RDataFactory.createIntVector(result, true, vector.getNames(attrProfiles));
}
}
......@@ -635,6 +635,10 @@ public final class CastBuilder {
return phaseBuilder -> CastIntegerNodeGen.create(false, false, false);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asIntegerVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) {
return phaseBuilder -> CastIntegerNodeGen.create(preserveNames, preserveDimensions, preserveAttributes);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asDouble() {
return phaseBuilder -> CastDoubleBaseNodeGen.create(false, false, false);
}
......@@ -643,6 +647,10 @@ public final class CastBuilder {
return phaseBuilder -> CastDoubleNodeGen.create(false, false, false);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asDoubleVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) {
return phaseBuilder -> CastDoubleNodeGen.create(preserveNames, preserveDimensions, preserveAttributes);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asString() {
return phaseBuilder -> CastStringBaseNodeGen.create(false, false, false);
}
......@@ -651,6 +659,10 @@ public final class CastBuilder {
return phaseBuilder -> CastStringNodeGen.create(false, false, false, false);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asStringVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) {
return phaseBuilder -> CastStringNodeGen.create(preserveNames, preserveDimensions, preserveAttributes, false);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asLogical() {
return phaseBuilder -> CastLogicalBaseNodeGen.create(false, false, false);
}
......@@ -659,6 +671,10 @@ public final class CastBuilder {
return phaseBuilder -> CastLogicalNodeGen.create(false, false, false);
}
public static <T> Function<ArgCastBuilder<T, ?>, CastNode> asLogicalVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) {
return phaseBuilder -> CastLogicalNodeGen.create(preserveNames, preserveDimensions, preserveAttributes);
}
public static <T> FindFirstNodeBuilder<T> findFirst(RBaseNode callObj, RError.Message message, Object... messageArgs) {
return new FindFirstNodeBuilder<>(callObj, message, messageArgs);
}
......
......@@ -35,7 +35,7 @@ import com.oracle.truffle.r.runtime.nodes.RBaseNode;
public abstract class CastNode extends UnaryNode {
@TruffleBoundary
public static void handleArgumentError(Object arg, RBaseNode callObj, RError.Message message, Object[] messageArgs) {
protected static void handleArgumentError(Object arg, RBaseNode callObj, RError.Message message, Object[] messageArgs) {
if (RContext.getInstance() == null) {
throw new IllegalArgumentException(String.format(message.message, CastBuilder.substituteArgPlaceholder(arg, messageArgs)));
} else {
......@@ -44,7 +44,7 @@ public abstract class CastNode extends UnaryNode {
}
@TruffleBoundary
public static void handleArgumentWarning(Object arg, RBaseNode callObj, RError.Message message, Object[] messageArgs) {
protected static void handleArgumentWarning(Object arg, RBaseNode callObj, RError.Message message, Object[] messageArgs) {
if (message == null) {
return;
}
......@@ -56,5 +56,4 @@ public abstract class CastNode extends UnaryNode {
RError.warning(callObj, message, CastBuilder.substituteArgPlaceholder(arg, messageArgs));
}
}
}
......@@ -56,11 +56,11 @@ public abstract class ConditionalMapNode extends CastNode {
@Specialization(guards = "doMap(x)")
protected Object map(Object x) {
return trueBranch.execute(x);
return trueBranch == null ? x : trueBranch.execute(x);
}
@Specialization(guards = "!doMap(x)")
protected Object noMap(Object x) {
return x;
return falseBranch == null ? x : falseBranch.execute(x);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment