From 8dfc4d19c04b162a242c1c7aef78a884b854ce98 Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Thu, 23 Mar 2017 11:25:33 +0100
Subject: [PATCH] FastR Grid: small code clean-up in unit conversions

---
 .../r/library/fastrGrid/GridUtils.java        | 13 ++++
 .../truffle/r/library/fastrGrid/Unit.java     | 69 +++++++++----------
 2 files changed, 44 insertions(+), 38 deletions(-)

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/GridUtils.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/GridUtils.java
index c7680a918f..b57eaaa994 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/GridUtils.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/GridUtils.java
@@ -50,6 +50,10 @@ final class GridUtils {
         return vec.getDataAt(idx % vec.getLength());
     }
 
+    static int getDataAtMod(RAbstractIntVector vec, int idx) {
+        return vec.getDataAt(idx % vec.getLength());
+    }
+
     @ExplodeLoop
     static int maxLength(UnitLengthNode unitLength, RAbstractVector... units) {
         int result = 0;
@@ -107,6 +111,15 @@ final class GridUtils {
         return (RList) value;
     }
 
+    static double getDoubleAt(RAbstractVector vector, int index) {
+        if (vector instanceof RAbstractDoubleVector) {
+            return ((RAbstractDoubleVector) vector).getDataAt(index);
+        } else if (vector instanceof RAbstractIntVector) {
+            return ((RAbstractIntVector) vector).getDataAt(index);
+        }
+        throw RError.error(RError.NO_CALLER, Message.GENERIC, "Unexpected non double/integer value");
+    }
+
     static double asDouble(Object val) {
         if (val instanceof Double) {
             return (double) val;
diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/Unit.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/Unit.java
index fb8112cb3e..4cb343dfbf 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/Unit.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastrGrid/Unit.java
@@ -13,16 +13,17 @@ package com.oracle.truffle.r.library.fastrGrid;
 
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asAbstractContainer;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asDouble;
+import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asDoubleVector;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asIntVector;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asList;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.asListOrNull;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.fmax;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.fmin;
+import static com.oracle.truffle.r.library.fastrGrid.GridUtils.getDataAtMod;
+import static com.oracle.truffle.r.library.fastrGrid.GridUtils.getDoubleAt;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.hasRClass;
 import static com.oracle.truffle.r.library.fastrGrid.GridUtils.sum;
 import static com.oracle.truffle.r.library.fastrGrid.device.DrawingContext.INCH_TO_POINTS_FACTOR;
-import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
-import static com.oracle.truffle.r.nodes.builtin.casts.fluent.CastNodeBuilder.newCastBuilder;
 
 import java.util.function.BiFunction;
 
@@ -35,9 +36,7 @@ import com.oracle.truffle.r.library.fastrGrid.UnitFactory.UnitToInchesNodeGen;
 import com.oracle.truffle.r.library.fastrGrid.ViewPortTransform.GetViewPortTransformNode;
 import com.oracle.truffle.r.library.fastrGrid.device.DrawingContext;
 import com.oracle.truffle.r.library.fastrGrid.device.GridDevice;
-import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode;
 import com.oracle.truffle.r.nodes.helpers.InheritsCheckNode;
-import com.oracle.truffle.r.nodes.unary.CastNode;
 import com.oracle.truffle.r.runtime.ArgumentsSignature;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RError.Message;
@@ -51,7 +50,6 @@ import com.oracle.truffle.r.runtime.data.RList;
 import com.oracle.truffle.r.runtime.data.RNull;
 import com.oracle.truffle.r.runtime.data.model.RAbstractContainer;
 import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 import com.oracle.truffle.r.runtime.nodes.RBaseNode;
@@ -116,6 +114,13 @@ public class Unit {
     private static final int L_minimising = 6;
     private static final int L_multiplying = 7;
 
+    // attributes in the unit object and unit classes
+    private static final String UNIT_ATTR_DATA = "data";
+    private static final String UNIT_ATTR_UNIT_ID = "valid.unit";
+    private static final String UNIT_CLASS = "unit";
+    private static final String UNIT_ARITHMETIC_CLASS = "unit.arithmetic";
+    private static final String UNIT_LIST_CLASS = "unit.list";
+
     private static final double CM_IN_INCH = 2.54;
 
     public static double inchesToCm(double inches) {
@@ -129,9 +134,9 @@ public class Unit {
     public static RAbstractDoubleVector newUnit(double value, int unitId) {
         assert unitId > 0 && unitId <= LAST_NORMAL_UNIT;
         RDoubleVector result = RDataFactory.createDoubleVector(new double[]{value}, RDataFactory.COMPLETE_VECTOR);
-        result.setClassAttr(RDataFactory.createStringVectorFromScalar("unit"));
-        result.setAttr("valid.unit", unitId);
-        result.setAttr("data", RNull.instance);
+        result.setClassAttr(RDataFactory.createStringVectorFromScalar(UNIT_CLASS));
+        result.setAttr(UNIT_ATTR_UNIT_ID, unitId);
+        result.setAttr(UNIT_ATTR_DATA, RNull.instance);
         return result;
     }
 
@@ -233,11 +238,11 @@ public class Unit {
     }
 
     static boolean isListUnit(Object unit) {
-        return unit instanceof RList && hasRClass((RAttributable) unit, "unit.list");
+        return unit instanceof RList && hasRClass((RAttributable) unit, UNIT_LIST_CLASS);
     }
 
     static boolean isArithmeticUnit(Object unit) {
-        return unit instanceof RList && hasRClass((RAttributable) unit, "unit.arithmetic");
+        return unit instanceof RList && hasRClass((RAttributable) unit, UNIT_ARITHMETIC_CLASS);
     }
 
     private static boolean isGrobUnit(int unitId) {
@@ -321,8 +326,8 @@ public class Unit {
     }
 
     abstract static class UnitNodeBase extends RBaseNode {
-        @Child private InheritsCheckNode inheritsArithmeticCheckNode = new InheritsCheckNode("unit.arithmetic");
-        @Child private InheritsCheckNode inheritsUnitListCheckNode = new InheritsCheckNode("unit.list");
+        @Child private InheritsCheckNode inheritsArithmeticCheckNode = new InheritsCheckNode(UNIT_ARITHMETIC_CLASS);
+        @Child private InheritsCheckNode inheritsUnitListCheckNode = new InheritsCheckNode(UNIT_LIST_CLASS);
 
         boolean isSimple(Object obj) {
             return !inheritsArithmeticCheckNode.execute(obj) && !inheritsUnitListCheckNode.execute(obj);
@@ -337,19 +342,6 @@ public class Unit {
         }
     }
 
-    public static final class UnitUnitIdNode extends Node {
-        @Child private GetFixedAttributeNode getValidUnitsAttr = GetFixedAttributeNode.create(VALID_UNIT_ATTR);
-
-        public int execute(RAbstractContainer unit, int index) {
-            RAbstractIntVector validUnits = asIntVector(getValidUnitsAttr.execute(unit));
-            return validUnits.getDataAt(index % validUnits.getLength());
-        }
-
-        public static UnitUnitIdNode create() {
-            return new UnitUnitIdNode();
-        }
-    }
-
     public static final class IsRelativeUnitNode extends UnitNodeBase {
         @Child private RGridCodeCall isPureNullCall = new RGridCodeCall("isPureNullUnit");
 
@@ -467,7 +459,7 @@ public class Unit {
      * interpreted as cyclic.
      */
     public abstract static class UnitToInchesNode extends UnitNodeBase {
-        @Child GrobUnitToInches grobUnitToInches = new GrobUnitToInches();
+        @Child GrobUnitToInches grobUnitToInches;
 
         public static UnitToInchesNode create() {
             return UnitToInchesNodeGen.create();
@@ -496,17 +488,14 @@ public class Unit {
         public abstract double execute(RAbstractContainer vector, int index, UnitConversionContext ctx, AxisOrDimension axisOrDim);
 
         @Specialization(guards = "isSimple(value)")
-        double doNormal(RAbstractContainer value, int index, UnitConversionContext ctx, AxisOrDimension axisOrDim,
-                        @Cached("createAsDoubleCast()") CastNode asDoubleCast,
-                        @Cached("create()") UnitUnitIdNode unitUnitId) {
-            int unitId = unitUnitId.execute(value, index);
-            RAbstractDoubleVector vector = (RAbstractDoubleVector) asDoubleCast.execute(value);
-            double scalarValue = vector.getDataAt(index % vector.getLength());
+        double doNormal(RAbstractVector value, int index, UnitConversionContext ctx, AxisOrDimension axisOrDim) {
+            int unitId = getDataAtMod(asIntVector(value.getAttr(UNIT_ATTR_UNIT_ID)), index);
+            double scalarValue = getDoubleAt(value, index % value.getLength());
             if (isGrobUnit(unitId)) {
-                RList grobList = asList(vector.getAttr("data"));
-                return grobUnitToInches.execute(scalarValue, unitId, grobList.getDataAt(index % grobList.getLength()), ctx);
+                RList grobList = asList(value.getAttr(UNIT_ATTR_DATA));
+                return getGrobUnitToInchesNode().execute(scalarValue, unitId, grobList.getDataAt(index % grobList.getLength()), ctx);
             }
-            return convertToInches(scalarValue, unitId, asListOrNull(vector.getAttr("data")), ctx, axisOrDim);
+            return convertToInches(scalarValue, unitId, asListOrNull(value.getAttr(UNIT_ATTR_DATA)), ctx, axisOrDim);
         }
 
         @Specialization(guards = "isUnitList(value)")
@@ -531,7 +520,7 @@ public class Unit {
                 case "-":
                     return recursive.apply(expr.arg1, L_subtracting) - recursive.apply(expr.arg2, L_subtracting);
                 case "*":
-                    RAbstractDoubleVector left = GridUtils.asDoubleVector(expr.arg1);
+                    RAbstractDoubleVector left = asDoubleVector(expr.arg1);
                     return left.getDataAt(index % left.getLength()) * recursive.apply(expr.arg2, L_multiplying);
                 default:
                     break;
@@ -575,8 +564,12 @@ public class Unit {
             return L_plain;
         }
 
-        static CastNode createAsDoubleCast() {
-            return newCastBuilder().mustBe(numericValue()).asDoubleVector().buildCastNode();
+        private GrobUnitToInches getGrobUnitToInchesNode() {
+            if (grobUnitToInches == null) {
+                CompilerDirectives.transferToInterpreterAndInvalidate();
+                grobUnitToInches = insert(new GrobUnitToInches());
+            }
+            return grobUnitToInches;
         }
     }
 
-- 
GitLab