Skip to content
Snippets Groups Projects
Commit aecc38e0 authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

Gamma and bitwise builtins refactored using CP

parent 2467293b
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,8 @@ import static com.oracle.truffle.r.library.stats.StatsUtil.DBL_MIN_EXP;
import static com.oracle.truffle.r.library.stats.StatsUtil.M_LOG10_2;
import static com.oracle.truffle.r.library.stats.StatsUtil.M_PI;
import static com.oracle.truffle.r.library.stats.StatsUtil.fmax2;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
import static com.oracle.truffle.r.runtime.RDispatch.MATH_GROUP_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
......@@ -26,11 +28,11 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.library.stats.GammaFunctions;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.BaseGammaFunctionsFactory.DpsiFnCalcNodeGen;
import com.oracle.truffle.r.runtime.RError;
......@@ -39,7 +41,6 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.closures.RClosures;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
......@@ -71,6 +72,11 @@ public class BaseGammaFunctions {
private final NACheck naValCheck = NACheck.create();
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("x").mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue(), RError.Message.NON_NUMERIC_MATH).asDoubleVector();
}
@Specialization
protected RDoubleVector lgamma(RAbstractDoubleVector x) {
naValCheck.enable(true);
......@@ -93,17 +99,6 @@ public class BaseGammaFunctions {
return lgamma(RClosures.createLogicalToDoubleVector(x));
}
@Specialization
@TruffleBoundary
protected Object lgamma(@SuppressWarnings("unused") RAbstractComplexVector x) {
return RError.error(this, RError.Message.UNIMPLEMENTED_COMPLEX_FUN);
}
@Fallback
@TruffleBoundary
protected Object lgamma(@SuppressWarnings("unused") Object x) {
throw RError.error(this, RError.Message.NON_NUMERIC_MATH);
}
}
@RBuiltin(name = "digamma", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
......@@ -121,6 +116,11 @@ public class BaseGammaFunctions {
return dpsiFnCalc.executeDouble(x, n, kode, ans);
}
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("x").mustBe(complexValue().not(), RError.Message.UNIMPLEMENTED_COMPLEX_FUN).mustBe(numericValue(), RError.Message.NON_NUMERIC_MATH).asDoubleVector();
}
@Specialization
protected RDoubleVector digamma(RAbstractDoubleVector x) {
naValCheck.enable(x);
......@@ -156,17 +156,6 @@ public class BaseGammaFunctions {
return digamma(RClosures.createLogicalToDoubleVector(x));
}
@Specialization
@TruffleBoundary
protected Object digamma(@SuppressWarnings("unused") RAbstractComplexVector x) {
return RError.error(this, RError.Message.UNIMPLEMENTED_COMPLEX_FUN);
}
@Fallback
@TruffleBoundary
protected Object digamma(@SuppressWarnings("unused") Object x) {
throw RError.error(this, RError.Message.NON_NUMERIC_MATH);
}
}
@NodeChildren({@NodeChild(value = "x"), @NodeChild(value = "n"), @NodeChild(value = "kode"), @NodeChild(value = "ans")})
......
......@@ -11,25 +11,27 @@
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import java.util.function.Function;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import com.oracle.truffle.r.nodes.binary.CastTypeNode;
import com.oracle.truffle.r.nodes.binary.CastTypeNodeGen;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.unary.TypeofNode;
import com.oracle.truffle.r.nodes.unary.TypeofNodeGen;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
......@@ -37,13 +39,10 @@ public class BitwiseFunctions {
public abstract static class BasicBitwise extends RBuiltinNode {
private final BranchProfile errorProfile = BranchProfile.create();
private final NACheck naCheckA = NACheck.create();
private final NACheck naCheckB = NACheck.create();
private final LoopConditionProfile loopProfile = LoopConditionProfile.createCountingProfile();
@Child private CastTypeNode castTypeA = CastTypeNodeGen.create(null, null);
@Child private CastTypeNode castTypeB = CastTypeNodeGen.create(null, null);
@Child private TypeofNode typeofA = TypeofNodeGen.create();
@Child private TypeofNode typeofB = TypeofNodeGen.create();
......@@ -62,10 +61,7 @@ public class BitwiseFunctions {
}
}
protected Object basicBit(RAbstractVector a, RAbstractVector b, Operation op) {
checkBasicBit(a, b, op);
RAbstractIntVector aVec = (RAbstractIntVector) castTypeA.execute(a, RType.Integer);
RAbstractIntVector bVec = (RAbstractIntVector) castTypeB.execute(b, RType.Integer);
protected Object basicBit(RAbstractIntVector aVec, RAbstractIntVector bVec, Operation op) {
naCheckA.enable(aVec);
naCheckB.enable(bVec);
int aLen = aVec.getLength();
......@@ -119,8 +115,7 @@ public class BitwiseFunctions {
return RDataFactory.createIntVector(ans, completeVector);
}
protected Object bitNot(RAbstractVector a) {
RAbstractIntVector aVec = (RAbstractIntVector) castTypeA.execute(a, RType.Integer);
protected Object bitNot(RAbstractIntVector aVec) {
int[] ans = new int[aVec.getLength()];
for (int i = 0; i < aVec.getLength(); i++) {
ans[i] = ~aVec.getDataAt(i);
......@@ -136,107 +131,144 @@ public class BitwiseFunctions {
return RDataFactory.createIntVector(na, RDataFactory.INCOMPLETE_VECTOR);
}
protected void checkBasicBit(RAbstractVector a, RAbstractVector b, Operation op) {
hasSameTypes(a, b);
hasSupportedType(a, op);
}
protected void checkShiftOrNot(RAbstractVector a, Operation op) {
hasSupportedType(a, op);
}
protected void hasSameTypes(RAbstractVector a, RAbstractVector b) {
RType aType = typeofA.execute(a);
RType bType = typeofB.execute(b);
boolean aCorrectType = (aType == RType.Integer || aType == RType.Double) ? true : false;
boolean bCorrectType = (bType == RType.Integer || bType == RType.Double) ? true : false;
if ((aCorrectType && bCorrectType) || aType == bType) {
return;
} else {
errorProfile.enter();
throw RError.error(this, RError.Message.SAME_TYPE, "a", "b");
}
}
protected void hasSupportedType(RAbstractVector a, Operation op) {
if (!(a instanceof RAbstractIntVector) && !(a instanceof RAbstractDoubleVector)) {
errorProfile.enter();
String type = typeofA.execute(a).getName();
throw RError.error(this, RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, type, op.name);
}
protected Function<Object, String> getArgType() {
return x -> typeofA.execute(x).getName();
}
protected boolean shiftByCharacter(RAbstractVector n) {
return typeofB.execute(n) == RType.Character;
}
}
@RBuiltin(name = "bitwiseAnd", kind = INTERNAL, parameterNames = {"a", "b"}, behavior = PURE)
public abstract static class BitwiseAnd extends BasicBitwise {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.AND.name).asIntegerVector();
casts.arg("b").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.AND.name).asIntegerVector();
}
@Specialization
protected Object bitwAnd(RAbstractVector a, RAbstractVector b) {
protected Object bitwAnd(RAbstractIntVector a, RAbstractIntVector b) {
return basicBit(a, b, Operation.AND);
}
@Fallback
@SuppressWarnings("unused")
protected Object differentTypes(Object a, Object b) {
throw RError.error(this, RError.Message.SAME_TYPE, "a", "b");
}
}
@RBuiltin(name = "bitwiseOr", kind = INTERNAL, parameterNames = {"a", "b"}, behavior = PURE)
public abstract static class BitwiseOr extends BasicBitwise {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.OR.name).asIntegerVector();
casts.arg("b").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.OR.name).asIntegerVector();
}
@Specialization
protected Object bitwOr(RAbstractVector a, RAbstractVector b) {
protected Object bitwOr(RAbstractIntVector a, RAbstractIntVector b) {
return basicBit(a, b, Operation.OR);
}
@Fallback
@SuppressWarnings("unused")
protected Object differentTypes(Object a, Object b) {
throw RError.error(this, RError.Message.SAME_TYPE, "a", "b");
}
}
@RBuiltin(name = "bitwiseXor", kind = INTERNAL, parameterNames = {"a", "b"}, behavior = PURE)
public abstract static class BitwiseXor extends BasicBitwise {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.XOR.name).asIntegerVector();
casts.arg("b").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.XOR.name).asIntegerVector();
}
@Specialization
protected Object bitwXor(RAbstractVector a, RAbstractVector b) {
protected Object bitwXor(RAbstractIntVector a, RAbstractIntVector b) {
return basicBit(a, b, Operation.XOR);
}
@Fallback
@SuppressWarnings("unused")
protected Object differentTypes(Object a, Object b) {
throw RError.error(this, RError.Message.SAME_TYPE, "a", "b");
}
}
@RBuiltin(name = "bitwiseShiftR", kind = INTERNAL, parameterNames = {"a", "n"}, behavior = PURE)
public abstract static class BitwiseShiftR extends BasicBitwise {
@Specialization(guards = {"!shiftByCharacter(n)"})
protected Object bitwShiftR(RAbstractVector a, RAbstractVector n) {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.SHIFTR.name).asIntegerVector();
casts.arg("n").mapIf(stringValue(), asStringVector(), asIntegerVector());
}
@Specialization
protected Object bitwShiftR(RAbstractIntVector a, RAbstractIntVector n) {
return basicBit(a, n, Operation.SHIFTR);
}
@Specialization(guards = {"shiftByCharacter(n)"})
@Specialization
@SuppressWarnings("unused")
protected Object bitwShiftR(RAbstractIntVector a, RNull n) {
return RDataFactory.createEmptyIntVector();
}
@Specialization
@SuppressWarnings("unused")
protected Object bitwShiftRChar(RAbstractVector a, RAbstractVector n) {
checkShiftOrNot(a, Operation.SHIFTR);
protected Object bitwShiftRChar(RAbstractIntVector a, RAbstractStringVector n) {
return makeNA(a.getLength());
}
}
@RBuiltin(name = "bitwiseShiftL", kind = INTERNAL, parameterNames = {"a", "n"}, behavior = PURE)
public abstract static class BitwiseShiftL extends BasicBitwise {
@Specialization(guards = {"!shiftByCharacter(n)"})
protected Object bitwShiftR(RAbstractVector a, RAbstractVector n) {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.ROOTNODE, RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.SHIFTL.name).asIntegerVector();
casts.arg("n").mapIf(stringValue(), chain(asStringVector()).with(shouldBe(anyValue().not(), RError.SHOW_CALLER, false, RError.Message.NA_INTRODUCED_COERCION)).end(), asIntegerVector());
}
@Specialization
protected Object bitwShiftL(RAbstractIntVector a, RAbstractIntVector n) {
return basicBit(a, n, Operation.SHIFTL);
}
@Specialization(guards = {"shiftByCharacter(n)"})
@Specialization
@SuppressWarnings("unused")
protected Object bitwShiftRChar(RAbstractVector a, RAbstractVector n) {
checkShiftOrNot(a, Operation.SHIFTL);
protected Object bitwShiftL(RAbstractIntVector a, RNull n) {
return RDataFactory.createEmptyIntVector();
}
@Specialization
@SuppressWarnings("unused")
protected Object bitwShiftLChar(RAbstractVector a, RAbstractStringVector n) {
return makeNA(a.getLength());
}
}
@RBuiltin(name = "bitwiseNot", kind = INTERNAL, parameterNames = {"a"}, behavior = PURE)
public abstract static class BitwiseNot extends BasicBitwise {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("a").mustBe(doubleValue().or(integerValue()), RError.Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, getArgType(), Operation.NOT.name).asIntegerVector();
}
@Specialization
protected Object bitwNot(RAbstractVector a) {
checkShiftOrNot(a, Operation.NOT);
protected Object bitwNot(RAbstractIntVector a) {
return bitNot(a);
}
}
}
......@@ -9403,6 +9403,10 @@ Warning message:
In bitwShiftL(c(3, 2, 4), c(3 + (0+3i))) :
imaginary parts discarded in coercion
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_bitwiseShiftL.testBitwiseFunctions
#{ bitwShiftL(c(8,4,2), NULL) }
integer(0)
##com.oracle.truffle.r.test.builtins.TestBuiltin_bitwiseShiftR.testBitwiseFunctions
#{ bitwShiftR(c(1,2,3,4), c("Hello")) }
[1] NA NA NA NA
......@@ -37,6 +37,6 @@ public class TestBuiltin_bitwiseAnd extends TestBase {
assertEval(Ignored.Unknown, "{ bitwAnd(NULL, NULL) }");
assertEval(Ignored.Unknown, "{ bitwAnd(c(), c(1,2,3)) }");
// Error message mismatch
assertEval(Output.IgnoreErrorContext, "{ bitwAnd(c(1,2,3,4), c(TRUE)) }");
assertEval(Output.IgnoreErrorMessage, "{ bitwAnd(c(1,2,3,4), c(TRUE)) }");
}
}
......@@ -27,6 +27,6 @@ public class TestBuiltin_bitwiseOr extends TestBase {
assertEval("{ bitwOr(c(10,11,12,13,14,15), c(1,1,1,1,1,1)) }");
assertEval("{ bitwOr(c(25,57,66), c(10,20,30,40,50,60)) }");
// Error message mismatch
assertEval(Output.IgnoreErrorContext, "{ bitwOr(c(1,2,3,4), c(3+3i)) }");
assertEval(Output.IgnoreErrorMessage, "{ bitwOr(c(1,2,3,4), c(3+3i)) }");
}
}
......@@ -27,14 +27,15 @@ public class TestBuiltin_bitwiseShiftL extends TestBase {
assertEval("{ bitwShiftL(c(10,11,12,13,14,15), c(1,1,1,1,1,1)) }");
assertEval("{ bitwShiftL(c(100,200,300), 1) }");
assertEval("{ bitwShiftL(c(25,57,66), c(10,20,30,40,50,60)) }");
assertEval("{ bitwShiftL(c(8,4,2), NULL) }");
assertEval(Output.IgnoreErrorContext, "{ bitwShiftL(TRUE, c(TRUE, FALSE)) }");
assertEval("{ bitwShiftL(TRUE, c(TRUE, FALSE)) }");
// Error message mismatch
assertEval(Ignored.Unknown, Output.IgnoreErrorContext, "{ bitwShiftL(c(3+3i), c(3,2,4)) }");
// Warning message mismatch
assertEval(Ignored.Unknown, "{ bitwShiftL(c(3,2,4), c(3+3i)) }");
// No warning message printed for NAs produced by coercion
assertEval(Ignored.Unknown, "{ bitwShiftL(c(1,2,3,4), c(\"a\")) }");
assertEval("{ bitwShiftL(c(1,2,3,4), c(\"a\")) }");
}
}
......@@ -28,6 +28,6 @@ public class TestBuiltin_bitwiseXor extends TestBase {
assertEval("{ bitwXor(c(25,57,66), c(10,20,30,40,50,60)) }");
assertEval("{ bitwXor(20,30) }");
assertEval(Output.IgnoreErrorContext, "{ bitwXor(c(\"r\"), c(16,17)) }");
assertEval(Output.IgnoreErrorMessage, "{ bitwXor(c(\"r\"), c(16,17)) }");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment