Skip to content
Snippets Groups Projects
Commit 915856a1 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

implement "mean" using VectorAccess

parent 82153817
No related branches found
No related tags found
No related merge requests found
...@@ -26,81 +26,79 @@ import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; ...@@ -26,81 +26,79 @@ import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_SUMMARY; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_SUMMARY;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
@ImportStatic(RType.class)
@RBuiltin(name = "mean", kind = INTERNAL, parameterNames = {"x"}, dispatch = INTERNAL_GENERIC, behavior = PURE_SUMMARY) @RBuiltin(name = "mean", kind = INTERNAL, parameterNames = {"x"}, dispatch = INTERNAL_GENERIC, behavior = PURE_SUMMARY)
public abstract class Mean extends RBuiltinNode.Arg1 { public abstract class Mean extends RBuiltinNode.Arg1 {
private final BranchProfile emptyProfile = BranchProfile.create();
@Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation();
@Child private BinaryArithmetic div = BinaryArithmetic.DIV.createOperation();
static { static {
Casts.noCasts(Mean.class); Casts.noCasts(Mean.class);
} }
@Specialization @Specialization(guards = {"access.supports(x)", "access.getType() != Complex"})
protected double mean(RAbstractDoubleVector x) { protected double meanDoubleCached(RAbstractVector x,
if (x.getLength() == 0) { @Cached("x.access()") VectorAccess access,
emptyProfile.enter(); @Cached("createBinaryProfile()") ConditionProfile emptyProfile) {
return Double.NaN; try (SequentialIterator iter = access.access(x)) {
if (emptyProfile.profile(!access.next(iter))) {
return Double.NaN;
}
double sum = 0;
do {
double value = access.getDouble(iter);
if (access.na.checkNAorNaN(value)) {
return value;
}
sum += value;
} while (access.next(iter));
return sum / access.getLength(iter);
} }
double sum = x.getDataAt(0);
for (int k = 1; k < x.getLength(); k++) {
sum = add.op(sum, x.getDataAt(k));
}
return div.op(sum, x.getLength());
} }
@Specialization @Specialization(replaces = "meanDoubleCached", guards = "x.getRType() != Complex")
protected double mean(RAbstractIntVector x) { protected double meanDoubleGeneric(RAbstractVector x,
if (x.getLength() == 0) { @Cached("createBinaryProfile()") ConditionProfile emptyProfile) {
emptyProfile.enter(); return meanDoubleCached(x, x.slowPathAccess(), emptyProfile);
return Double.NaN;
}
double sum = x.getDataAt(0);
for (int k = 1; k < x.getLength(); k++) {
sum = add.op(sum, x.getDataAt(k));
}
return div.op(sum, x.getLength());
} }
@Specialization @Specialization(guards = {"access.supports(x)", "access.getType() == Complex"})
protected double mean(RAbstractLogicalVector x) { protected RComplex meanComplexCached(RAbstractVector x,
if (x.getLength() == 0) { @Cached("x.access()") VectorAccess access,
emptyProfile.enter(); @Cached("createBinaryProfile()") ConditionProfile emptyProfile) {
return Double.NaN; try (SequentialIterator iter = access.access(x)) {
} if (emptyProfile.profile(!access.next(iter))) {
double sum = x.getDataAt(0); return RComplex.valueOf(Double.NaN, Double.NaN);
for (int k = 1; k < x.getLength(); k++) { }
sum = add.op(sum, x.getDataAt(k)); double sumR = 0;
double sumI = 0;
do {
double valueR = access.getComplexR(iter);
double valueI = access.getComplexI(iter);
if (access.na.check(valueR, valueI)) {
return RComplex.valueOf(valueR, valueI);
}
sumR += valueR;
sumI += valueI;
} while (access.next(iter));
int length = access.getLength(iter);
return RComplex.valueOf(sumR / length, sumI / length);
} }
return div.op(sum, x.getLength());
} }
@Specialization @Specialization(replaces = "meanComplexCached", guards = "x.getRType() == Complex")
protected RComplex mean(RAbstractComplexVector x) { protected RComplex meanComplexGeneric(RAbstractVector x,
if (x.getLength() == 0) { @Cached("createBinaryProfile()") ConditionProfile emptyProfile) {
emptyProfile.enter(); return meanComplexCached(x, x.slowPathAccess(), emptyProfile);
return RDataFactory.createComplex(Double.NaN, Double.NaN);
}
RComplex sum = x.getDataAt(0);
RComplex comp;
for (int k = 1; k < x.getLength(); k++) {
comp = x.getDataAt(k);
sum = add.op(sum.getRealPart(), sum.getImaginaryPart(), comp.getRealPart(), comp.getImaginaryPart());
}
return div.op(sum.getRealPart(), sum.getImaginaryPart(), x.getLength(), 0);
} }
} }
...@@ -39127,7 +39127,7 @@ integer(0) ...@@ -39127,7 +39127,7 @@ integer(0)
#argv <- structure(list(x = structure(c(31, NA, NA, 31), units = 'days', class = 'difftime'), na.rm = TRUE), .Names = c('x', 'na.rm'));do.call('mean', argv) #argv <- structure(list(x = structure(c(31, NA, NA, 31), units = 'days', class = 'difftime'), na.rm = TRUE), .Names = c('x', 'na.rm'));do.call('mean', argv)
Time difference of 31 days Time difference of 31 days
   
##com.oracle.truffle.r.test.builtins.TestBuiltin_mean.testmean2#Ignored.ImplementationError# ##com.oracle.truffle.r.test.builtins.TestBuiltin_mean.testmean2#
#argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]])) #argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]]))
[1] NA [1] NA
   
...@@ -24,11 +24,7 @@ public class TestBuiltin_mean extends TestBase { ...@@ -24,11 +24,7 @@ public class TestBuiltin_mean extends TestBase {
@Test @Test
public void testmean2() { public void testmean2() {
// FIXME NA is returned by GnuR for NA input assertEval("argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]]))");
// Expected output: [1] NA
// FastR output: [1] NaN
assertEval(Ignored.ImplementationError, "argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]]))");
} }
@Test @Test
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment