Skip to content
Snippets Groups Projects
Commit c2fa3e62 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

fixes in LogFunctions, CastStringNode and Order

parent d9e18431
Branches
No related tags found
No related merge requests found
......@@ -28,9 +28,13 @@ import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
......@@ -49,14 +53,10 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
import com.oracle.truffle.r.runtime.ops.na.NAProfile;
import java.util.Arrays;
import java.util.function.Function;
public class LogFunctions {
@RBuiltin(name = "log", kind = PRIMITIVE, parameterNames = {"x", "base"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
......@@ -158,103 +158,47 @@ public class LogFunctions {
return logb(x, base, divNode, naBase);
}
@Specialization
protected RDoubleVector log(RAbstractIntVector vector, double base,
@Specialization(guards = "!isRAbstractComplexVector(vector)")
protected RDoubleVector log(RAbstractVector vector, double base,
@Cached("createClassProfile()") ValueProfile vectorProfile,
@Cached("createBinaryProfile()") ConditionProfile isNAProfile,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> xNACheck.convertIntToDouble(vector.getDataAt(index)), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck);
}
@Specialization
protected RDoubleVector log(RAbstractDoubleVector vector, double base,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> checkDouble(vector.getDataAt(index), xNACheck), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck);
}
private static double checkDouble(double d, NACheck na) {
na.check(d);
return d;
}
@Specialization
protected RDoubleVector log(RAbstractLogicalVector vector, double base,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> xNACheck.convertLogicalToDouble(vector.getDataAt(index)), copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck);
RAbstractDoubleVector doubleVector = (RAbstractDoubleVector) vectorProfile.profile(vector).castSafe(RType.Double, isNAProfile);
return logInternal(doubleVector, base, copyAttrsNode, getNamesNode, getDimsNode, xNACheck, baseNACheck);
}
@Specialization
protected RComplexVector log(RAbstractComplexVector vector, double base,
@Cached("createClassProfile()") ValueProfile vectorProfile,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, RComplex.valueOf(base, 0), copyAttrsNode, getNamesNode, getDimsNode, divNode, xNACheck, baseNACheck);
return logInternal(vectorProfile.profile(vector), RComplex.valueOf(base, 0), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
}
@Specialization
protected RAbstractComplexVector log(RAbstractIntVector vector, RComplex base,
protected RAbstractComplexVector log(RAbstractVector vector, RComplex base,
@Cached("createClassProfile()") ValueProfile vectorProfile,
@Cached("createBinaryProfile()") ConditionProfile isNAProfile,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> xNACheck.convertIntToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
RAbstractComplexVector complexVector = (RAbstractComplexVector) vectorProfile.profile(vector).castSafe(RType.Complex, isNAProfile);
return logInternal(complexVector, base, divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
}
@Specialization
protected RAbstractComplexVector log(RAbstractDoubleVector vector, RComplex base,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> xNACheck.convertDoubleToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
}
@Specialization
protected RAbstractComplexVector log(RAbstractLogicalVector vector, RComplex base,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> xNACheck.convertLogicalToComplex(vector.getDataAt(index)), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
}
@Specialization
protected RComplexVector log(RAbstractComplexVector vector, RComplex base,
@Cached("create()") CopyOfRegAttributesNode copyAttrsNode,
@Cached("create()") GetNamesAttributeNode getNamesNode,
@Cached("create()") GetDimAttributeNode getDimsNode,
@Cached("createDivNode()") BinaryMapArithmeticFunctionNode divNode,
@Cached("create()") NACheck xNACheck,
@Cached("create()") NACheck baseNACheck) {
return log(vector, base, index -> checkComplex(vector.getDataAt(index), xNACheck), divNode, getDimsNode, getNamesNode, copyAttrsNode, xNACheck, baseNACheck);
}
private static RComplex checkComplex(RComplex rc, NACheck xNACheck) {
xNACheck.check(rc);
return rc;
}
private RDoubleVector log(RAbstractVector vector, double base, Function<Integer, Double> toDouble, CopyOfRegAttributesNode copyAttrsNode, GetNamesAttributeNode getNamesNode,
GetDimAttributeNode getDimsNode, NACheck xNACheck, NACheck baseNACheck) {
private RDoubleVector logInternal(RAbstractDoubleVector vector, double base, CopyOfRegAttributesNode copyAttrsNode, GetNamesAttributeNode getNamesNode, GetDimAttributeNode getDimsNode,
NACheck xNACheck, NACheck baseNACheck) {
baseNACheck.enable(base);
double[] resultVector = new double[vector.getLength()];
if (baseNACheck.check(base)) {
......@@ -266,12 +210,8 @@ public class LogFunctions {
xNACheck.enable(vector);
Runnable[] warningResult = new Runnable[1];
for (int i = 0; i < vector.getLength(); i++) {
double value = toDouble.apply(i);
if (!naX.isNA(value)) {
resultVector[i] = logb(value, base, warningResult);
} else {
resultVector[i] = value;
}
double value = vector.getDataAt(i);
resultVector[i] = xNACheck.check(value) ? RRuntime.DOUBLE_NA : logb(value, base, warningResult);
}
if (warningResult[0] != null) {
warningResult[0].run();
......@@ -315,8 +255,8 @@ public class LogFunctions {
return result;
}
private RComplexVector log(RAbstractVector vector, RComplex base, Function<Integer, RComplex> toComplex, BinaryMapArithmeticFunctionNode divNode, GetDimAttributeNode getDimsNode,
GetNamesAttributeNode getNamesNode, CopyOfRegAttributesNode copyAttrsNode, NACheck xNACheck, NACheck baseNACheck) {
private RComplexVector logInternal(RAbstractComplexVector vector, RComplex base, BinaryMapArithmeticFunctionNode divNode, GetDimAttributeNode getDimsNode, GetNamesAttributeNode getNamesNode,
CopyOfRegAttributesNode copyAttrsNode, NACheck xNACheck, NACheck baseNACheck) {
baseNACheck.enable(base);
double[] complexVector = new double[vector.getLength() * 2];
if (baseNACheck.check(base)) {
......@@ -328,13 +268,13 @@ public class LogFunctions {
xNACheck.enable(vector);
boolean seenNaN = false;
for (int i = 0; i < vector.getLength(); i++) {
RComplex value = toComplex.apply(i);
if (!naX.isNA(value)) {
RComplex value = vector.getDataAt(i);
if (xNACheck.check(value)) {
fill(complexVector, i * 2, value);
} else {
RComplex rc = logb(value, base, divNode, false);
seenNaN = isNaN(rc);
fill(complexVector, i * 2, rc);
} else {
fill(complexVector, i * 2, value);
}
}
if (seenNaN) {
......
......@@ -67,6 +67,7 @@ public abstract class Order extends RPrecedenceBuiltinNode {
private final BranchProfile error = BranchProfile.create();
private final ConditionProfile notRemoveNAs = ConditionProfile.createBinaryProfile();
private final ValueProfile vectorProfile = ValueProfile.createClassProfile();
/**
* For use by {@link RadixSort}.
......@@ -79,7 +80,8 @@ public abstract class Order extends RPrecedenceBuiltinNode {
return executeOrderVector1(v, naLast, dec, false);
}
private RIntVector executeOrderVector1(RAbstractVector v, byte naLast, boolean dec, boolean needsStringCollation) {
private RIntVector executeOrderVector1(RAbstractVector vIn, byte naLast, boolean dec, boolean needsStringCollation) {
RAbstractVector v = vectorProfile.profile(vIn);
int n = v.getLength();
reportWork(n);
......@@ -170,31 +172,27 @@ public abstract class Order extends RPrecedenceBuiltinNode {
@Specialization(guards = {"oneVec(args)", "isFirstIntegerPrecedence(args)"})
Object orderInt(byte naLast, boolean decreasing, RArgsValuesAndNames args) {
Object[] vectors = args.getArguments();
RAbstractIntVector v = (RAbstractIntVector) castVector(vectors[0]);
RAbstractIntVector v = (RAbstractIntVector) castVector(args.getArgument(0));
return executeOrderVector1(v, naLast, decreasing);
}
@Specialization(guards = {"oneVec(args)", "isFirstDoublePrecedence(args)"})
Object orderDouble(byte naLast, boolean decreasing, RArgsValuesAndNames args) {
Object[] vectors = args.getArguments();
RAbstractDoubleVector v = (RAbstractDoubleVector) castVector(vectors[0]);
RAbstractDoubleVector v = (RAbstractDoubleVector) castVector(args.getArgument(0));
return executeOrderVector1(v, naLast, decreasing);
}
@Specialization(guards = {"oneVec(args)", "isFirstLogicalPrecedence(args)"})
Object orderLogical(byte naLast, boolean decreasing, RArgsValuesAndNames args,
@Cached("createBinaryProfile()") ConditionProfile isNAProfile) {
Object[] vectors = args.getArguments();
RAbstractIntVector v = (RAbstractIntVector) castVector(vectors[0]).castSafe(RType.Integer, isNAProfile);
RAbstractIntVector v = (RAbstractIntVector) castVector(args.getArgument(0)).castSafe(RType.Integer, isNAProfile);
return executeOrderVector1(v, naLast, decreasing);
}
@Specialization(guards = {"oneVec(args)", "isFirstStringPrecedence(args)"})
Object orderString(byte naLast, boolean decreasing, RArgsValuesAndNames args,
@Cached("create()") BranchProfile collationProfile) {
Object[] vectors = args.getArguments();
RAbstractStringVector v = (RAbstractStringVector) castVector(vectors[0]);
RAbstractStringVector v = (RAbstractStringVector) castVector(args.getArgument(0));
int n = v.getLength();
boolean needsCollation = false;
outer: for (int i = 0; i < n; i++) {
......@@ -214,8 +212,7 @@ public abstract class Order extends RPrecedenceBuiltinNode {
@Specialization(guards = {"oneVec(args)", "isFirstComplexPrecedence( args)"})
Object orderComplex(byte naLast, boolean decreasing, RArgsValuesAndNames args) {
Object[] vectors = args.getArguments();
RAbstractComplexVector v = (RAbstractComplexVector) castVector(vectors[0]);
RAbstractComplexVector v = (RAbstractComplexVector) castVector(args.getArgument(0));
return executeOrderVector1(v, naLast, decreasing);
}
......
......@@ -25,6 +25,7 @@ package com.oracle.truffle.r.nodes.unary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.runtime.RDeparse;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RLanguage;
......@@ -65,8 +66,10 @@ public abstract class CastStringNode extends CastStringBaseNode {
}
@Specialization
protected RStringVector doAbstractContainer(RAbstractContainer operand,
protected RStringVector doAbstractContainer(RAbstractContainer operandIn,
@Cached("createClassProfile()") ValueProfile operandProfile,
@Cached("createBinaryProfile()") ConditionProfile isLanguageProfile) {
RAbstractContainer operand = operandProfile.profile(operandIn);
String[] sdata = new String[operand.getLength()];
// conversions to character will not introduce new NAs
for (int i = 0; i < operand.getLength(); i++) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment