Skip to content
Snippets Groups Projects
Commit e2cdb905 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert unary and binary (boolean) arithmetic nodes to VectorAccess

parent 9303cb46
No related branches found
No related tags found
No related merge requests found
......@@ -87,29 +87,27 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 {
}
@Specialization(limit = "CACHE_LIMIT", guards = {"cached != null", "cached.isSupported(left, right)"})
protected Object doNumericVectorCached(Object left, Object right,
protected Object doNumericVectorCached(RAbstractVector left, RAbstractVector right,
@Cached("createFastCached(left, right)") BinaryMapNode cached) {
return cached.apply(left, right);
}
@Specialization(replaces = "doNumericVectorCached", guards = {"isNumericVector(left)", "isNumericVector(right)"})
@TruffleBoundary
protected Object doNumericVectorGeneric(Object left, Object right,
protected Object doNumericVectorGeneric(RAbstractVector left, RAbstractVector right,
@Cached("binary.createOperation()") BinaryArithmetic arithmetic,
@Cached("new(createCached(arithmetic, left, right))") GenericNumericVectorNode generic) {
RAbstractVector leftVector = (RAbstractVector) left;
RAbstractVector rightVector = (RAbstractVector) right;
return generic.get(arithmetic, leftVector, rightVector).apply(leftVector, rightVector);
@Cached("createGeneric()") GenericNumericVectorNode generic) {
return generic.get(arithmetic, left, right).apply(left, right);
}
protected BinaryMapNode createFastCached(Object left, Object right) {
protected BinaryMapNode createFastCached(RAbstractVector left, RAbstractVector right) {
if (isNumericVector(left) && isNumericVector(right)) {
return createCached(binary.createOperation(), left, right);
return createCached(binary.createOperation(), left, right, false);
}
return null;
}
protected static boolean isNumericVector(Object value) {
protected static boolean isNumericVector(RAbstractVector value) {
return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector;
}
......@@ -133,19 +131,19 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 {
}
@Specialization(guards = {"isNumericVector(right)"})
protected Object doLeftNull(@SuppressWarnings("unused") RNull left, Object right,
protected Object doLeftNull(@SuppressWarnings("unused") RNull left, RAbstractVector right,
@Cached("createClassProfile()") ValueProfile classProfile) {
RType rType = ((RAbstractVector) classProfile.profile(right)).getRType();
if (rType == RType.Complex) {
RType type = classProfile.profile(right).getRType();
if (type == RType.Complex) {
return RDataFactory.createEmptyComplexVector();
} else {
if (rType == RType.Integer || rType == RType.Logical) {
if (type == RType.Integer || type == RType.Logical) {
if (operation instanceof BinaryArithmetic.Div || operation instanceof BinaryArithmetic.Pow) {
return RType.Double.getEmpty();
} else {
return RType.Integer.getEmpty();
}
} else if (rType == RType.Double) {
} else if (type == RType.Double) {
return RType.Double.getEmpty();
} else {
throw error(Message.NON_NUMERIC_BINARY);
......@@ -154,7 +152,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 {
}
@Specialization(guards = {"isNumericVector(left)"})
protected Object doRightNull(Object left, RNull right,
protected Object doRightNull(RAbstractVector left, RNull right,
@Cached("createClassProfile()") ValueProfile classProfile) {
return doLeftNull(right, left, classProfile);
}
......@@ -164,7 +162,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 {
throw error(Message.NON_NUMERIC_BINARY);
}
protected static BinaryMapNode createCached(BinaryArithmetic innerArithmetic, Object left, Object right) {
protected static BinaryMapNode createCached(BinaryArithmetic innerArithmetic, Object left, Object right, boolean isGeneric) {
RAbstractVector leftVector = (RAbstractVector) left;
RAbstractVector rightVector = (RAbstractVector) right;
......@@ -174,22 +172,22 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 {
resultType = RType.Double;
}
return BinaryMapNode.create(new BinaryMapArithmeticFunctionNode(innerArithmetic), leftVector, rightVector, argumentType, resultType, true);
return BinaryMapNode.create(new BinaryMapArithmeticFunctionNode(innerArithmetic), leftVector, rightVector, argumentType, resultType, true, isGeneric);
}
protected static GenericNumericVectorNode createGeneric() {
return new GenericNumericVectorNode();
}
protected static final class GenericNumericVectorNode extends TruffleBoundaryNode {
@Child private BinaryMapNode cached;
public GenericNumericVectorNode(BinaryMapNode cachedOperation) {
this.cached = insert(cachedOperation);
}
public BinaryMapNode get(BinaryArithmetic arithmetic, RAbstractVector left, RAbstractVector right) {
CompilerAsserts.neverPartOfCompilation();
BinaryMapNode map = cached;
if (!map.isSupported(left, right)) {
cached = map = map.replace(createCached(arithmetic, left, right));
if (map == null || !map.isSupported(left, right)) {
cached = map = insert(createCached(arithmetic, left, right, true));
}
return map;
}
......
......@@ -104,24 +104,22 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 {
}
@Specialization(limit = "CACHE_LIMIT", guards = {"cached != null", "cached.isSupported(left, right)"})
protected Object doNumericVectorCached(Object left, Object right,
protected Object doNumericVectorCached(RAbstractVector left, RAbstractVector right,
@Cached("createFastCached(left, right)") BinaryMapNode cached) {
return cached.apply(left, right);
}
@Specialization(replaces = "doNumericVectorCached", guards = "isSupported(left, right)")
@TruffleBoundary
protected Object doNumericVectorGeneric(Object left, Object right,
protected Object doNumericVectorGeneric(RAbstractVector left, RAbstractVector right,
@Cached("factory.createOperation()") BooleanOperation operation,
@Cached("new(createCached(operation, left, right))") GenericNumericVectorNode generic) {
RAbstractVector leftVector = (RAbstractVector) left;
RAbstractVector rightVector = (RAbstractVector) right;
return generic.get(operation, leftVector, rightVector).apply(leftVector, rightVector);
@Cached("createGeneric()") GenericNumericVectorNode generic) {
return generic.get(operation, left, right).apply(left, right);
}
protected BinaryMapNode createFastCached(Object left, Object right) {
if (isSupported(left, right)) {
return createCached(factory.createOperation(), left, right);
return createCached(factory.createOperation(), left, right, false);
}
return null;
}
......@@ -243,7 +241,7 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 {
throw error(Message.OPERATIONS_NUMERIC_LOGICAL_COMPLEX);
}
protected static BinaryMapNode createCached(BooleanOperation operation, Object left, Object right) {
protected static BinaryMapNode createCached(BooleanOperation operation, Object left, Object right, boolean isGeneric) {
RAbstractVector leftVector = (RAbstractVector) left;
RAbstractVector rightVector = (RAbstractVector) right;
......@@ -255,23 +253,24 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 {
resultType = RType.Logical;
}
return BinaryMapNode.create(new BinaryMapBooleanFunctionNode(operation), leftVector, rightVector, argumentType, resultType, false);
return BinaryMapNode.create(new BinaryMapBooleanFunctionNode(operation), leftVector, rightVector, argumentType, resultType, false, isGeneric);
}
protected static GenericNumericVectorNode createGeneric() {
return new GenericNumericVectorNode();
}
protected static final class GenericNumericVectorNode extends TruffleBoundaryNode {
@Child private BinaryMapNode cached;
public GenericNumericVectorNode(BinaryMapNode cachedOperation) {
this.cached = insert(cachedOperation);
}
private BinaryMapNode get(BooleanOperation arithmetic, RAbstractVector left, RAbstractVector right) {
CompilerAsserts.neverPartOfCompilation();
if (!cached.isSupported(left, right)) {
cached = cached.replace(createCached(arithmetic, left, right));
BinaryMapNode map = cached;
if (map == null || !map.isSupported(left, right)) {
cached = map = insert(createCached(arithmetic, left, right, true));
}
return cached;
return map;
}
}
}
......@@ -52,42 +52,45 @@ public abstract class UnaryArithmeticNode extends UnaryNode {
public abstract Object execute(Object value);
@Specialization(guards = {"cachedNode != null", "cachedNode.isSupported(operand)"})
protected Object doCached(Object operand,
protected Object doCached(RAbstractVector operand,
@Cached("createCachedFast(operand)") UnaryMapNode cachedNode) {
return cachedNode.apply(operand);
}
protected UnaryMapNode createCachedFast(Object operand) {
protected UnaryMapNode createCachedFast(RAbstractVector operand) {
if (isNumericVector(operand)) {
return createCached(unary.createOperation(), operand);
return createCached(unary.createOperation(), operand, false);
}
return null;
}
protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand) {
protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand, boolean isGeneric) {
if (operand instanceof RAbstractVector) {
RAbstractVector castOperand = (RAbstractVector) operand;
RType operandType = castOperand.getRType();
if (operandType.isNumeric()) {
RType type = RType.maxPrecedence(operandType, arithmetic.getMinPrecedence());
RType resultType = arithmetic.calculateResultType(type);
return UnaryMapNode.create(new ScalarUnaryArithmeticNode(arithmetic), castOperand, type, resultType);
return UnaryMapNode.create(new ScalarUnaryArithmeticNode(arithmetic), castOperand, type, resultType, isGeneric);
}
}
return null;
}
protected static boolean isNumericVector(Object value) {
protected static boolean isNumericVector(RAbstractVector value) {
return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector;
}
@Specialization(replaces = "doCached", guards = {"isNumericVector(operand)"})
@TruffleBoundary
protected Object doGeneric(Object operand,
protected Object doGeneric(RAbstractVector operand,
@Cached("unary.createOperation()") UnaryArithmetic arithmetic,
@Cached("new(createCached(arithmetic, operand))") GenericNumericVectorNode generic) {
RAbstractVector operandVector = (RAbstractVector) operand;
return generic.get(arithmetic, operandVector).apply(operandVector);
@Cached("createGeneric()") GenericNumericVectorNode generic) {
return generic.get(arithmetic, operand).apply(operand);
}
protected static GenericNumericVectorNode createGeneric() {
return new GenericNumericVectorNode();
}
@Override
......@@ -110,16 +113,12 @@ public abstract class UnaryArithmeticNode extends UnaryNode {
@Child private UnaryMapNode cached;
public GenericNumericVectorNode(UnaryMapNode cachedOperation) {
this.cached = cachedOperation;
}
public UnaryMapNode get(UnaryArithmetic arithmetic, RAbstractVector operand) {
UnaryMapNode next = cached;
if (!next.isSupported(operand)) {
next = cached.replace(createCached(arithmetic, operand));
UnaryMapNode map = cached;
if (map == null || !map.isSupported(operand)) {
cached = map = insert(createCached(arithmetic, operand, true));
}
return next;
return map;
}
}
}
......@@ -124,6 +124,7 @@ public abstract class BinaryArithmetic extends Operation {
public static final class Add extends BinaryArithmetic {
@CompilationFinal private boolean introducesOverflow = false;
@CompilationFinal private boolean introducesNA = false;
public Add() {
......@@ -137,17 +138,22 @@ public abstract class BinaryArithmetic extends Operation {
@Override
public boolean introducesNA() {
return introducesNA;
return introducesNA || introducesOverflow;
}
@Override
public int op(int left, int right) {
if (!introducesNA) {
if (!introducesOverflow) {
try {
return Math.addExact(left, right);
int result = Math.addExact(left, right);
// NAs can also be introduced without a 32-bit overflow
if (!introducesNA && result == RRuntime.INT_NA) {
CompilerDirectives.transferToInterpreterAndInvalidate();
introducesNA = true;
}
} catch (ArithmeticException e) {
CompilerDirectives.transferToInterpreterAndInvalidate();
introducesNA = true;
introducesOverflow = true;
}
}
// Borrowed from ExactMath
......@@ -176,6 +182,7 @@ public abstract class BinaryArithmetic extends Operation {
public static final class Subtract extends BinaryArithmetic {
@CompilationFinal private boolean introducesOverflow = false;
@CompilationFinal private boolean introducesNA = false;
public Subtract() {
......@@ -189,17 +196,23 @@ public abstract class BinaryArithmetic extends Operation {
@Override
public boolean introducesNA() {
return introducesNA;
return introducesNA || introducesOverflow;
}
@Override
public int op(int left, int right) {
if (!introducesNA) {
if (!introducesOverflow) {
try {
return Math.subtractExact(left, right);
int result = Math.subtractExact(left, right);
// NAs can also be introduced without a 32-bit overflow
if (!introducesNA && result == RRuntime.INT_NA) {
CompilerDirectives.transferToInterpreterAndInvalidate();
introducesNA = true;
}
return result;
} catch (ArithmeticException e) {
CompilerDirectives.transferToInterpreterAndInvalidate();
introducesNA = true;
introducesOverflow = true;
}
}
// Borrowed from ExactMath
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment