Skip to content
Snippets Groups Projects
Commit f479cf47 authored by Lukas Stadler's avatar Lukas Stadler Committed by stepan
Browse files

proper internal generic dispatch in as.vector

parent c465bdcd
Branches
No related tags found
No related merge requests found
...@@ -24,12 +24,21 @@ package com.oracle.truffle.r.nodes.builtin.base; ...@@ -24,12 +24,21 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.runtime.RBuiltinKind.INTERNAL; import static com.oracle.truffle.r.runtime.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.AsVectorNodeGen.AsVectorInternalNodeGen;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
import com.oracle.truffle.r.nodes.function.S3FunctionLookupNode;
import com.oracle.truffle.r.nodes.function.UseMethodInternalNode;
import com.oracle.truffle.r.nodes.unary.CastComplexNode; import com.oracle.truffle.r.nodes.unary.CastComplexNode;
import com.oracle.truffle.r.nodes.unary.CastDoubleNode; import com.oracle.truffle.r.nodes.unary.CastDoubleNode;
import com.oracle.truffle.r.nodes.unary.CastExpressionNode; import com.oracle.truffle.r.nodes.unary.CastExpressionNode;
...@@ -39,7 +48,9 @@ import com.oracle.truffle.r.nodes.unary.CastListNodeGen; ...@@ -39,7 +48,9 @@ import com.oracle.truffle.r.nodes.unary.CastListNodeGen;
import com.oracle.truffle.r.nodes.unary.CastLogicalNode; import com.oracle.truffle.r.nodes.unary.CastLogicalNode;
import com.oracle.truffle.r.nodes.unary.CastRawNode; import com.oracle.truffle.r.nodes.unary.CastRawNode;
import com.oracle.truffle.r.nodes.unary.CastSymbolNode; import com.oracle.truffle.r.nodes.unary.CastSymbolNode;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.RBuiltin; import com.oracle.truffle.r.runtime.RBuiltin;
import com.oracle.truffle.r.runtime.RDispatch;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.RType;
...@@ -52,215 +63,227 @@ import com.oracle.truffle.r.runtime.data.RIntVector; ...@@ -52,215 +63,227 @@ import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RInteger; import com.oracle.truffle.r.runtime.data.RInteger;
import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RLogical; import com.oracle.truffle.r.runtime.data.RLogical;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RRaw; import com.oracle.truffle.r.runtime.data.RRaw;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RSymbol; import com.oracle.truffle.r.runtime.data.RSymbol;
import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
@RBuiltin(name = "as.vector", kind = INTERNAL, parameterNames = {"x", "mode"}) @RBuiltin(name = "as.vector", kind = INTERNAL, parameterNames = {"x", "mode"}, dispatch = RDispatch.INTERNAL_GENERIC)
public abstract class AsVector extends RBuiltinNode { public abstract class AsVector extends RBuiltinNode {
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); @Child private AsVectorInternal internal = AsVectorInternalNodeGen.create();
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
@Child private UseMethodInternalNode useMethod;
private final ConditionProfile hasClassProfile = ConditionProfile.createBinaryProfile();
@Override @Override
protected void createCasts(CastBuilder casts) { protected void createCasts(CastBuilder casts) {
casts.firstStringWithError(1, RError.Message.INVALID_ARGUMENT, "mode"); casts.firstStringWithError(1, RError.Message.INVALID_ARGUMENT, "mode");
} }
@Specialization protected static AsVectorInternal createInternal() {
protected Object asVector(RNull x, @SuppressWarnings("unused") RMissing mode) { return AsVectorInternalNodeGen.create();
controlVisibility();
return x;
} }
@Specialization(guards = "castToString(mode)") private static final ArgumentsSignature SIGNATURE = ArgumentsSignature.get("x", "mode");
protected Object asVectorString(Object x, @SuppressWarnings("unused") String mode, //
@Cached("create()") AsCharacter asCharacter) {
controlVisibility();
return asCharacter.execute(x);
}
@Specialization(guards = "castToInt(x, mode)") @Specialization
protected Object asVectorInt(RAbstractContainer x, @SuppressWarnings("unused") String mode, // protected Object asVector(VirtualFrame frame, Object x, String mode) {
@Cached("createNonPreserving()") CastIntegerNode cast) {
controlVisibility(); controlVisibility();
return cast.execute(x); RStringVector clazz = classHierarchy.execute(x);
if (hasClassProfile.profile(clazz != null)) {
if (useMethod == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
useMethod = insert(new UseMethodInternalNode("as.vector", SIGNATURE, false));
}
try {
return useMethod.execute(frame, clazz, new Object[]{x, mode});
} catch (S3FunctionLookupNode.NoGenericMethodException e) {
// fallthrough
}
}
return internal.execute(x, mode);
} }
@Specialization(guards = "castToDouble(x, mode)") public abstract static class AsVectorInternal extends Node {
protected Object asVectorDouble(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastDoubleNode cast) {
controlVisibility();
return cast.execute(x);
}
@Specialization(guards = "castToComplex(x, mode)") public abstract Object execute(Object x, String mode);
protected Object asVectorComplex(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastComplexNode cast) {
controlVisibility();
return cast.execute(x);
}
@Specialization(guards = "castToLogical(x, mode)") private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
protected Object asVectorLogical(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastLogicalNode cast) {
controlVisibility();
return cast.execute(x);
}
@Specialization(guards = "castToRaw(x, mode)") @Specialization(guards = "castToString(mode)")
protected Object asVectorRaw(RAbstractContainer x, @SuppressWarnings("unused") String mode, // protected Object asVectorString(Object x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastRawNode cast) { @Cached("create()") AsCharacter asCharacter) {
controlVisibility(); return asCharacter.execute(x);
return cast.execute(x); }
}
protected static CastListNode createListCast() { @Specialization(guards = "castToInt(x, mode)")
return CastListNodeGen.create(true, false, false); protected Object asVectorInt(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
} @Cached("createNonPreserving()") CastIntegerNode cast) {
return cast.execute(x);
}
@Specialization(guards = "castToList(mode)") @Specialization(guards = "castToDouble(x, mode)")
protected Object asVectorList(RAbstractContainer x, @SuppressWarnings("unused") String mode, // protected Object asVectorDouble(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createListCast()") CastListNode cast) { @Cached("createNonPreserving()") CastDoubleNode cast) {
controlVisibility(); return cast.execute(x);
return cast.execute(x); }
}
@Specialization(guards = "castToSymbol(x, mode)") @Specialization(guards = "castToComplex(x, mode)")
protected Object asVectorSymbol(RAbstractContainer x, @SuppressWarnings("unused") String mode, // protected Object asVectorComplex(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastSymbolNode cast) { @Cached("createNonPreserving()") CastComplexNode cast) {
controlVisibility(); return cast.execute(x);
return cast.execute(x); }
}
@Specialization(guards = "castToExpression(mode)") @Specialization(guards = "castToLogical(x, mode)")
protected Object asVectorExpression(Object x, @SuppressWarnings("unused") String mode, // protected Object asVectorLogical(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
@Cached("createNonPreserving()") CastExpressionNode cast) { @Cached("createNonPreserving()") CastLogicalNode cast) {
controlVisibility(); return cast.execute(x);
return cast.execute(x); }
}
@Specialization(guards = "castToList(mode)") @Specialization(guards = "castToRaw(x, mode)")
protected RAbstractVector asVectorList(@SuppressWarnings("unused") RNull x, @SuppressWarnings("unused") String mode) { protected Object asVectorRaw(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
controlVisibility(); @Cached("createNonPreserving()") CastRawNode cast) {
return RDataFactory.createList(); return cast.execute(x);
} }
@Specialization(guards = "isSymbol(x, mode)") protected static CastListNode createListCast() {
protected RSymbol asVectorSymbol(RSymbol x, @SuppressWarnings("unused") String mode) { return CastListNodeGen.create(true, false, false);
controlVisibility(); }
String sName = x.getName();
return RDataFactory.createSymbol(sName);
}
protected boolean isSymbol(@SuppressWarnings("unused") RSymbol x, String mode) { @Specialization(guards = "castToList(mode)")
return RType.Symbol.getName().equals(mode); protected Object asVectorList(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
} @Cached("createListCast()") CastListNode cast) {
return cast.execute(x);
}
@Specialization(guards = "modeIsAny(mode)") @Specialization(guards = "castToSymbol(x, mode)")
protected RAbstractVector asVector(RList x, @SuppressWarnings("unused") String mode) { protected Object asVectorSymbol(RAbstractContainer x, @SuppressWarnings("unused") String mode, //
controlVisibility(); @Cached("createNonPreserving()") CastSymbolNode cast) {
RList result = x.copyWithNewDimensions(null); return cast.execute(x);
result.copyNamesFrom(attrProfiles, x); }
return result;
}
@Specialization(guards = "modeIsAny(mode)") @Specialization(guards = "castToExpression(mode)")
protected RAbstractVector asVector(RFactor x, @SuppressWarnings("unused") String mode) { protected Object asVectorExpression(Object x, @SuppressWarnings("unused") String mode, //
RVector levels = x.getLevels(attrProfiles); @Cached("createNonPreserving()") CastExpressionNode cast) {
RVector result = levels.createEmptySameType(x.getLength(), RDataFactory.COMPLETE_VECTOR); return cast.execute(x);
RIntVector factorData = x.getVector();
for (int i = 0; i < result.getLength(); i++) {
result.transferElementSameType(i, levels, factorData.getDataAt(i) - 1);
} }
return result;
}
@Specialization(guards = "modeIsAny(mode)") @Specialization(guards = "castToList(mode)")
protected RNull asVector(RNull x, @SuppressWarnings("unused") String mode) { protected RAbstractVector asVectorList(@SuppressWarnings("unused") RNull x, @SuppressWarnings("unused") String mode) {
controlVisibility(); return RDataFactory.createList();
return x; }
}
@Specialization(guards = "modeIsPairList(mode)") @Specialization(guards = "isSymbol(x, mode)")
protected Object asVectorPairList(RList x, @SuppressWarnings("unused") String mode) { protected RSymbol asVectorSymbol(RSymbol x, @SuppressWarnings("unused") String mode) {
controlVisibility(); String sName = x.getName();
// TODO implement non-empty element list conversion; this is a placeholder for type test return RDataFactory.createSymbol(sName);
if (x.getLength() == 0) {
return RNull.instance;
} else {
throw RError.nyi(this, "non-empty lists");
} }
}
@Specialization(guards = "modeIsAny(mode)") protected boolean isSymbol(@SuppressWarnings("unused") RSymbol x, String mode) {
protected RAbstractVector asVectorAny(RAbstractVector x, @SuppressWarnings("unused") String mode) { return RType.Symbol.getName().equals(mode);
controlVisibility(); }
return x.copyWithNewDimensions(null);
}
@Specialization(guards = "modeMatches(x, mode)") @Specialization(guards = "modeIsAny(mode)")
protected RAbstractVector asVector(RAbstractVector x, @SuppressWarnings("unused") String mode) { protected RAbstractVector asVector(RList x, @SuppressWarnings("unused") String mode) {
controlVisibility(); RList result = x.copyWithNewDimensions(null);
return x.copyWithNewDimensions(null); result.copyNamesFrom(attrProfiles, x);
} return result;
}
protected boolean castToInt(RAbstractContainer x, String mode) { @Specialization(guards = "modeIsAny(mode)")
return x.getElementClass() != RInteger.class && RType.Integer.getName().equals(mode); protected RAbstractVector asVector(RFactor x, @SuppressWarnings("unused") String mode) {
} RVector levels = x.getLevels(attrProfiles);
RVector result = levels.createEmptySameType(x.getLength(), RDataFactory.COMPLETE_VECTOR);
RIntVector factorData = x.getVector();
for (int i = 0; i < result.getLength(); i++) {
result.transferElementSameType(i, levels, factorData.getDataAt(i) - 1);
}
return result;
}
protected boolean castToDouble(RAbstractContainer x, String mode) { @Specialization(guards = "modeIsAny(mode)")
return x.getElementClass() != RDouble.class && (RType.Double.getClazz().equals(mode) || RType.Double.getName().equals(mode)); protected RNull asVector(RNull x, @SuppressWarnings("unused") String mode) {
} return x;
}
protected boolean castToComplex(RAbstractContainer x, String mode) { @Specialization(guards = "modeIsPairList(mode)")
return x.getElementClass() != RComplex.class && RType.Complex.getName().equals(mode); protected Object asVectorPairList(RList x, @SuppressWarnings("unused") String mode) {
} // TODO implement non-empty element list conversion; this is a placeholder for type test
if (x.getLength() == 0) {
return RNull.instance;
} else {
throw RError.nyi(RError.SHOW_CALLER, "non-empty lists");
}
}
protected boolean castToLogical(RAbstractContainer x, String mode) { @Specialization(guards = "modeIsAny(mode)")
return x.getElementClass() != RLogical.class && RType.Logical.getName().equals(mode); protected RAbstractVector asVectorAny(RAbstractVector x, @SuppressWarnings("unused") String mode) {
} return x.copyWithNewDimensions(null);
}
protected boolean castToString(String mode) { @Specialization(guards = "modeMatches(x, mode)")
return RType.Character.getName().equals(mode); protected RAbstractVector asVector(RAbstractVector x, @SuppressWarnings("unused") String mode) {
} return x.copyWithNewDimensions(null);
}
protected boolean castToRaw(RAbstractContainer x, String mode) { protected boolean castToInt(RAbstractContainer x, String mode) {
return x.getElementClass() != RRaw.class && RType.Raw.getName().equals(mode); return x.getElementClass() != RInteger.class && RType.Integer.getName().equals(mode);
} }
protected boolean castToList(String mode) { protected boolean castToDouble(RAbstractContainer x, String mode) {
return RType.List.getName().equals(mode); return x.getElementClass() != RDouble.class && (RType.Double.getClazz().equals(mode) || RType.Double.getName().equals(mode));
} }
protected boolean castToSymbol(RAbstractContainer x, String mode) { protected boolean castToComplex(RAbstractContainer x, String mode) {
return x.getElementClass() != Object.class && RType.Symbol.getName().equals(mode); return x.getElementClass() != RComplex.class && RType.Complex.getName().equals(mode);
} }
protected boolean castToExpression(String mode) { protected boolean castToLogical(RAbstractContainer x, String mode) {
return RType.Expression.getName().equals(mode); return x.getElementClass() != RLogical.class && RType.Logical.getName().equals(mode);
} }
protected boolean modeMatches(RAbstractVector x, String mode) { protected boolean castToString(String mode) {
return RRuntime.classToString(x.getElementClass()).equals(mode) || x.getElementClass() == RDouble.class && RType.Double.getName().equals(mode); return RType.Character.getName().equals(mode);
} }
protected boolean modeIsAny(String mode) { protected boolean castToRaw(RAbstractContainer x, String mode) {
return RType.Any.getName().equals(mode); return x.getElementClass() != RRaw.class && RType.Raw.getName().equals(mode);
} }
protected boolean modeIsPairList(String mode) { protected boolean castToList(String mode) {
return RType.PairList.getName().equals(mode); return RType.List.getName().equals(mode);
} }
@SuppressWarnings("unused") protected boolean castToSymbol(RAbstractContainer x, String mode) {
@Fallback return x.getElementClass() != Object.class && RType.Symbol.getName().equals(mode);
@TruffleBoundary }
protected RAbstractVector asVectorWrongMode(Object x, Object mode) {
controlVisibility(); protected boolean castToExpression(String mode) {
throw RError.error(RError.SHOW_CALLER, RError.Message.INVALID_ARGUMENT, "mode"); return RType.Expression.getName().equals(mode);
}
protected boolean modeMatches(RAbstractVector x, String mode) {
return RRuntime.classToString(x.getElementClass()).equals(mode) || x.getElementClass() == RDouble.class && RType.Double.getName().equals(mode);
}
protected boolean modeIsAny(String mode) {
return RType.Any.getName().equals(mode);
}
protected boolean modeIsPairList(String mode) {
return RType.PairList.getName().equals(mode);
}
@SuppressWarnings("unused")
@Fallback
@TruffleBoundary
protected RAbstractVector asVectorWrongMode(Object x, String mode) {
throw RError.error(RError.SHOW_CALLER, RError.Message.INVALID_ARGUMENT, "mode");
}
} }
} }
...@@ -41,7 +41,6 @@ public abstract class DimNames extends RBuiltinNode { ...@@ -41,7 +41,6 @@ public abstract class DimNames extends RBuiltinNode {
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
private final ConditionProfile nullProfile = ConditionProfile.createBinaryProfile(); private final ConditionProfile nullProfile = ConditionProfile.createBinaryProfile();
private final BranchProfile dataframeProfile = BranchProfile.create();
private final BranchProfile factorProfile = BranchProfile.create(); private final BranchProfile factorProfile = BranchProfile.create();
private final BranchProfile otherProfile = BranchProfile.create(); private final BranchProfile otherProfile = BranchProfile.create();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment