Skip to content
Snippets Groups Projects
Commit 516ec369 authored by stepan's avatar stepan
Browse files

StatsFunctions: support copying of attributes to the result

parent f96fae32
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,8 @@ import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.UnaryCopyAttributesNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.runtime.RError;
......@@ -49,8 +51,22 @@ public final class StatsFunctions {
double evaluate(double a, double b, double c, boolean x);
}
static final class StatFunctionProfiles {
final BranchProfile nan = BranchProfile.create();
final NACheck aCheck = NACheck.create();
final NACheck bCheck = NACheck.create();
final NACheck cCheck = NACheck.create();
final ConditionProfile copyAttrsFromA = ConditionProfile.createBinaryProfile();
final ConditionProfile copyAttrsFromB = ConditionProfile.createBinaryProfile();
final ConditionProfile copyAttrsFromC = ConditionProfile.createBinaryProfile();
public static StatFunctionProfiles create() {
return new StatFunctionProfiles();
}
}
private static RAbstractDoubleVector evaluate3(Node node, Function3_2 function, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y,
BranchProfile nan, NACheck aCheck, NACheck bCheck, NACheck cCheck) {
StatFunctionProfiles profiles, UnaryCopyAttributesNode copyAttributesNode) {
int aLength = a.getLength();
int bLength = b.getLength();
int cLength = c.getLength();
......@@ -63,17 +79,17 @@ public final class StatsFunctions {
boolean complete = true;
boolean nans = false;
aCheck.enable(a);
bCheck.enable(b);
cCheck.enable(c);
profiles.aCheck.enable(a);
profiles.bCheck.enable(b);
profiles.cCheck.enable(c);
for (int i = 0; i < length; i++) {
double aValue = a.getDataAt(i % aLength);
double bValue = b.getDataAt(i % bLength);
double cValue = c.getDataAt(i % cLength);
double value;
if (Double.isNaN(aValue) || Double.isNaN(bValue) || Double.isNaN(cValue)) {
nan.enter();
if (aCheck.check(aValue) || bCheck.check(bValue) || cCheck.check(cValue)) {
profiles.nan.enter();
if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue) || profiles.cCheck.check(cValue)) {
value = RRuntime.DOUBLE_NA;
complete = false;
} else {
......@@ -82,7 +98,7 @@ public final class StatsFunctions {
} else {
value = function.evaluate(aValue, bValue, cValue, x, y);
if (Double.isNaN(value)) {
nan.enter();
profiles.nan.enter();
nans = true;
}
}
......@@ -91,7 +107,18 @@ public final class StatsFunctions {
if (nans) {
RError.warning(RError.SHOW_CALLER, RError.Message.NAN_PRODUCED);
}
return RDataFactory.createDoubleVector(result, complete);
RDoubleVector resultVec = RDataFactory.createDoubleVector(result, complete);
// copy attributes if necessary:
if (profiles.copyAttrsFromA.profile(aLength == length)) {
copyAttributesNode.execute(resultVec, a);
} else if (profiles.copyAttrsFromB.profile(bLength == length)) {
copyAttributesNode.execute(resultVec, b);
} else if (profiles.copyAttrsFromC.profile(cLength == length)) {
copyAttributesNode.execute(resultVec, c);
}
return resultVec;
}
public abstract static class Function3_2Node extends RExternalBuiltinNode.Arg5 {
......@@ -111,12 +138,10 @@ public final class StatsFunctions {
}
@Specialization
protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y, //
@Cached("create()") BranchProfile nan, //
@Cached("create()") NACheck aCheck, //
@Cached("create()") NACheck bCheck, //
@Cached("create()") NACheck cCheck) {
return evaluate3(this, function, a, b, c, x, y, nan, aCheck, bCheck, cCheck);
protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y,
@Cached("create()") StatFunctionProfiles profiles,
@Cached("create()") UnaryCopyAttributesNode copyAttributesNode) {
return evaluate3(this, function, a, b, c, x, y, profiles, copyAttributesNode);
}
}
......@@ -136,12 +161,10 @@ public final class StatsFunctions {
}
@Specialization
protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, //
@Cached("create()") BranchProfile nan, //
@Cached("create()") NACheck aCheck, //
@Cached("create()") NACheck bCheck, //
@Cached("create()") NACheck cCheck) {
return evaluate3(this, function, a, b, c, x, false /* dummy */, nan, aCheck, bCheck, cCheck);
protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x,
@Cached("create()") StatFunctionProfiles profiles,
@Cached("create()") UnaryCopyAttributesNode copyAttributesNode) {
return evaluate3(this, function, a, b, c, x, false /* dummy */, profiles, copyAttributesNode);
}
}
......
......@@ -49,6 +49,10 @@ public abstract class UnaryCopyAttributesNode extends RBaseNode {
this.copyAllAttributes = copyAllAttributes;
}
public static UnaryCopyAttributesNode create() {
return UnaryCopyAttributesNodeGen.create(true);
}
public abstract RAbstractVector execute(RAbstractVector target, RAbstractVector left);
protected boolean containsMetadata(RAbstractVector vector, RAttributeProfiles attrProfiles) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment