Skip to content
Snippets Groups Projects
Commit 3aadff8d authored by Lukas Stadler's avatar Lukas Stadler
Browse files

prevent NaN to 0L conversion in subset/subscript specials

parent 076cb9fe
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
*/ */
package com.oracle.truffle.r.nodes.builtin.base.infix; package com.oracle.truffle.r.nodes.builtin.base.infix;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren; import com.oracle.truffle.api.dsl.NodeChildren;
...@@ -30,7 +31,6 @@ import com.oracle.truffle.api.dsl.TypeSystemReference; ...@@ -30,7 +31,6 @@ import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.NodeCost; import com.oracle.truffle.api.nodes.NodeCost;
import com.oracle.truffle.api.nodes.NodeInfo; import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
...@@ -176,14 +176,6 @@ class SpecialsUtils { ...@@ -176,14 +176,6 @@ class SpecialsUtils {
@TypeSystemReference(EmptyTypeSystemFlatLayout.class) @TypeSystemReference(EmptyTypeSystemFlatLayout.class)
public abstract static class ConvertIndex extends RNode { public abstract static class ConvertIndex extends RNode {
private final boolean isSubset;
private final ConditionProfile zeroProfile;
ConvertIndex(boolean isSubset) {
this.isSubset = isSubset;
this.zeroProfile = isSubset ? null : ConditionProfile.createBinaryProfile();
}
protected abstract RNode getDelegate(); protected abstract RNode getDelegate();
@Specialization @Specialization
...@@ -191,14 +183,20 @@ class SpecialsUtils { ...@@ -191,14 +183,20 @@ class SpecialsUtils {
return value; return value;
} }
@Specialization @Specialization(rewriteOn = IllegalArgumentException.class)
protected int convertDouble(double value) { protected int convertDouble(double value) {
// Conversion from double to an index differs in subscript and subset.
int intValue = (int) value; int intValue = (int) value;
if (isSubset) { if (intValue == 0) {
return intValue; /*
* Conversion from double to an index differs in subscript and subset for values in
* the ]0..1[ range (subscript interprets 0.1 as 1, whereas subset treats it as 0).
* We avoid this special case by simply going to the more generic case for this
* range. Additionally, (int) Double.NaN is 0, which is also caught by this case.
*/
CompilerDirectives.transferToInterpreterAndInvalidate();
throw new IllegalArgumentException();
} else { } else {
return zeroProfile.profile(intValue == 0) ? (value == 0 ? 0 : 1) : intValue; return intValue;
} }
} }
...@@ -217,11 +215,7 @@ class SpecialsUtils { ...@@ -217,11 +215,7 @@ class SpecialsUtils {
return new ProfiledValue(value); return new ProfiledValue(value);
} }
public static ConvertIndex convertSubscript(RNode value) { public static ConvertIndex convertIndex(RNode value) {
return ConvertIndexNodeGen.create(false, value); return ConvertIndexNodeGen.create(value);
}
public static ConvertIndex convertSubset(RNode value) {
return ConvertIndexNodeGen.create(true, value);
} }
} }
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
package com.oracle.truffle.r.nodes.builtin.base.infix; package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
...@@ -193,9 +193,9 @@ public abstract class Subscript extends RBuiltinNode { ...@@ -193,9 +193,9 @@ public abstract class Subscript extends RBuiltinNode {
public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) { public static RNode special(ArgumentsSignature signature, RNode[] arguments, boolean inReplacement) {
if (signature.getNonNullCount() == 0) { if (signature.getNonNullCount() == 0) {
if (arguments.length == 2) { if (arguments.length == 2) {
return SubscriptSpecialNodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1])); return SubscriptSpecialNodeGen.create(inReplacement, profile(arguments[0]), convertIndex(arguments[1]));
} else if (arguments.length == 3) { } else if (arguments.length == 3) {
return SubscriptSpecial2NodeGen.create(inReplacement, profile(arguments[0]), convertSubscript(arguments[1]), convertSubscript(arguments[2])); return SubscriptSpecial2NodeGen.create(inReplacement, profile(arguments[0]), convertIndex(arguments[1]), convertIndex(arguments[2]));
} }
} }
return null; return null;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
package com.oracle.truffle.r.nodes.builtin.base.infix; package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
...@@ -81,7 +81,7 @@ abstract class SubsetSpecial extends SubscriptSpecialBase { ...@@ -81,7 +81,7 @@ abstract class SubsetSpecial extends SubscriptSpecialBase {
} }
@Specialization(guards = {"simpleVector(vector)", "!inReplacement"}) @Specialization(guards = {"simpleVector(vector)", "!inReplacement"})
protected static Object access(VirtualFrame frame, RAbstractVector vector, Object index, protected Object access(VirtualFrame frame, RAbstractVector vector, Object index,
@Cached("createAccess()") ExtractVectorNode extract) { @Cached("createAccess()") ExtractVectorNode extract) {
return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE); return extract.apply(frame, vector, new Object[]{index}, RRuntime.LOGICAL_TRUE, RLogical.TRUE);
} }
...@@ -124,11 +124,11 @@ public abstract class Subset extends RBuiltinNode { ...@@ -124,11 +124,11 @@ public abstract class Subset extends RBuiltinNode {
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (signature.getNonNullCount() == 0 && (args.length == 2 || args.length == 3)) { if (signature.getNonNullCount() == 0 && (args.length == 2 || args.length == 3)) {
ProfiledValue profiledVector = profile(args[0]); ProfiledValue profiledVector = profile(args[0]);
ConvertIndex index = convertSubset(args[1]); ConvertIndex index = convertIndex(args[1]);
if (args.length == 2) { if (args.length == 2) {
return SubsetSpecialNodeGen.create(inReplacement, profiledVector, index); return SubsetSpecialNodeGen.create(inReplacement, profiledVector, index);
} else { } else {
return SubsetSpecial2NodeGen.create(inReplacement, profiledVector, index, convertSubset(args[2])); return SubsetSpecial2NodeGen.create(inReplacement, profiledVector, index, convertIndex(args[2]));
} }
} }
return null; return null;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
package com.oracle.truffle.r.nodes.builtin.base.infix; package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubscript; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
...@@ -168,11 +168,11 @@ public abstract class UpdateSubscript extends RBuiltinNode { ...@@ -168,11 +168,11 @@ public abstract class UpdateSubscript extends RBuiltinNode {
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) { if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]); ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertSubscript(args[1]); ConvertIndex index = convertIndex(args[1]);
if (args.length == 3) { if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]); return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]);
} else { } else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubscript(args[2]), args[3]); return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertIndex(args[2]), args[3]);
} }
} }
return null; return null;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
package com.oracle.truffle.r.nodes.builtin.base.infix; package com.oracle.truffle.r.nodes.builtin.base.infix;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertSubset; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.convertIndex;
import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile; import static com.oracle.truffle.r.nodes.builtin.base.infix.SpecialsUtils.profile;
import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
...@@ -57,11 +57,11 @@ public abstract class UpdateSubset extends RBuiltinNode { ...@@ -57,11 +57,11 @@ public abstract class UpdateSubset extends RBuiltinNode {
public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) { public static RNode special(ArgumentsSignature signature, RNode[] args, boolean inReplacement) {
if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) { if (SpecialsUtils.isCorrectUpdateSignature(signature) && (args.length == 3 || args.length == 4)) {
ProfiledValue vector = profile(args[0]); ProfiledValue vector = profile(args[0]);
ConvertIndex index = convertSubset(args[1]); ConvertIndex index = convertIndex(args[1]);
if (args.length == 3) { if (args.length == 3) {
return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]); return UpdateSubscriptSpecialNodeGen.create(inReplacement, vector, index, args[2]);
} else { } else {
return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertSubset(args[2]), args[3]); return UpdateSubscriptSpecial2NodeGen.create(inReplacement, vector, index, convertIndex(args[2]), args[3]);
} }
} }
return null; return null;
......
...@@ -234,7 +234,7 @@ public class SpecialCallTest extends TestBase { ...@@ -234,7 +234,7 @@ public class SpecialCallTest extends TestBase {
assertCallCounts("a <- c(1,2,3,4)", "a[[4]] <- 1", 1, 0, 2, 0); assertCallCounts("a <- c(1,2,3,4)", "a[[4]] <- 1", 1, 0, 2, 0);
assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); assertCallCounts("a <- list(c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0);
assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0); assertCallCounts("a <- list(a=c(1,2,3,4),2,3)", "a[[1]] <- 1", 1, 0, 2, 0);
assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]] <- 1", 1, 0, 2, 0); assertCallCounts("a <- c(1,2,3,4)", "a[[0.1]] <- 1", 1, 0, 1, 1);
assertCallCounts("a <- c(1,2,3,4)", "a[[5]] <- 1", 1, 0, 1, 1); assertCallCounts("a <- c(1,2,3,4)", "a[[5]] <- 1", 1, 0, 1, 1);
assertCallCounts("a <- c(1,2,3,4)", "a[[0]] <- 1", 1, 0, 1, 1); assertCallCounts("a <- c(1,2,3,4)", "a[[0]] <- 1", 1, 0, 1, 1);
assertCallCounts("a <- c(1,2,3,4); b <- -1", "a[[b]] <- 1", 1, 0, 1, 1); assertCallCounts("a <- c(1,2,3,4); b <- -1", "a[[b]] <- 1", 1, 0, 1, 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment