From 75a2b49498014ccffa68986442575be8b2071768 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Mon, 24 Apr 2017 15:12:58 +0200
Subject: [PATCH] allocate less memory in solve.default

---
 .../r/nodes/builtin/base/LaFunctions.java     | 58 +++++++++++++++----
 1 file changed, 47 insertions(+), 11 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
index d1515e966d..0756c474f9 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
@@ -29,6 +29,8 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
+import java.lang.ref.SoftReference;
+import java.lang.ref.WeakReference;
 import java.util.function.Function;
 
 import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
@@ -561,10 +563,31 @@ public class LaFunctions {
         }
     }
 
+    private static final class NativeArrayCache {
+        private final ThreadLocal<SoftReference<double[]>> cache = new ThreadLocal<>();
+
+        @TruffleBoundary
+        private double[] get(int minLength) {
+            SoftReference<double[]> cached = cache.get();
+            double[] array;
+            if (cached != null) {
+                array = cached.get();
+                if (array != null && array.length >= minLength) {
+                    return array;
+                }
+            }
+            array = new double[minLength];
+            cache.set(new SoftReference<>(array));
+            return array;
+        }
+    }
+
     @RBuiltin(name = "La_solve", kind = INTERNAL, parameterNames = {"a", "bin", "tolin"}, behavior = PURE)
     public abstract static class LaSolve extends RBuiltinNode.Arg3 {
         @Child private CastDoubleNode castDouble = CastDoubleNodeGen.create(false, false, false);
 
+        private static final NativeArrayCache aCache = new NativeArrayCache();
+
         private static Function<RAbstractDoubleVector, Object> getDimVal(int dim) {
             return vec -> vec.getDimensions()[dim];
         }
@@ -605,8 +628,8 @@ public class LaFunctions {
             int p;
             double[] bData;
             RDoubleVector b;
-            if (bin.isMatrix()) {
-                int[] bDims = getBinDimsNode.getDimensions(bin);
+            int[] bDims = getBinDimsNode.getDimensions(bin);
+            if (GetDimAttributeNode.isMatrix(bDims)) {
                 p = bDims[1];
                 if (p == 0) {
                     throw error(Message.GENERIC, "no right-hand side in 'b'");
@@ -615,7 +638,12 @@ public class LaFunctions {
                 if (p2 != n) {
                     throw error(Message.MUST_BE_SQUARE_COMPATIBLE, "b", p2, p, "a", n, n);
                 }
-                bData = new double[n * p];
+                if (bin.getLength() == n * p) {
+                    bData = bin.materialize().getDataNonShared();
+                } else {
+                    bData = new double[n];
+                    System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p);
+                }
                 b = RDataFactory.createDoubleVector(bData, RDataFactory.COMPLETE_VECTOR);
                 setBDimsNode.setDimensions(b, new int[]{n, p});
                 RList binDn = getBinDimNamesNode.getDimNames(bin);
@@ -634,24 +662,32 @@ public class LaFunctions {
                 if (bin.getLength() != n) {
                     throw error(Message.MUST_BE_SQUARE_COMPATIBLE, "b", bin.getLength(), p, "a", n, n);
                 }
-                bData = new double[n];
+                if (bin.getLength() == n) {
+                    bData = bin.materialize().getDataNonShared();
+                } else {
+                    bData = new double[n];
+                    System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p);
+                }
                 b = RDataFactory.createDoubleVector(bData, RDataFactory.COMPLETE_VECTOR);
                 if (aDn != null) {
                     setNamesNode.setNames(b, RDataFactory.createStringVector((String) aDn.getDataAt(1)));
                 }
             }
 
-            System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p);
-
             int[] ipiv = new int[n];
             // work on a copy of A
-            double[] avals = new double[n * n];
+            RDoubleVector aDouble;
             if (a instanceof RAbstractDoubleVector) {
-                System.arraycopy(a.getInternalStore(), 0, avals, 0, n * n);
+                aDouble = ((RAbstractDoubleVector) a).materialize();
+            } else {
+                aDouble = (RDoubleVector) castDouble.doCast(a);
+            }
+            double[] avals;
+            if (aDouble.isShared()) {
+                avals = aCache.get(aDouble.getLength());
+                System.arraycopy(aDouble.getDataWithoutCopying(), 0, avals, 0, n * p);
             } else {
-                RDoubleVector aDouble = (RDoubleVector) castDouble.doCast(a);
-                assert aDouble != a;
-                avals = aDouble.getInternalStore();
+                avals = aDouble.getDataWithoutCopying();
             }
             int info = dgesvNode.execute(n, p, avals, n, ipiv, bData, n);
             if (info < 0) {
-- 
GitLab