From 75a2b49498014ccffa68986442575be8b2071768 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Mon, 24 Apr 2017 15:12:58 +0200 Subject: [PATCH] allocate less memory in solve.default --- .../r/nodes/builtin/base/LaFunctions.java | 58 +++++++++++++++---- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java index d1515e966d..0756c474f9 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java @@ -29,6 +29,8 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; +import java.lang.ref.SoftReference; +import java.lang.ref.WeakReference; import java.util.function.Function; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; @@ -561,10 +563,31 @@ public class LaFunctions { } } + private static final class NativeArrayCache { + private final ThreadLocal<SoftReference<double[]>> cache = new ThreadLocal<>(); + + @TruffleBoundary + private double[] get(int minLength) { + SoftReference<double[]> cached = cache.get(); + double[] array; + if (cached != null) { + array = cached.get(); + if (array != null && array.length >= minLength) { + return array; + } + } + array = new double[minLength]; + cache.set(new SoftReference<>(array)); + return array; + } + } + @RBuiltin(name = "La_solve", kind = INTERNAL, parameterNames = {"a", "bin", "tolin"}, behavior = PURE) public abstract static class LaSolve extends RBuiltinNode.Arg3 { @Child private CastDoubleNode castDouble = CastDoubleNodeGen.create(false, false, false); + private static final NativeArrayCache aCache = new NativeArrayCache(); + private static Function<RAbstractDoubleVector, Object> getDimVal(int dim) { return vec -> vec.getDimensions()[dim]; } @@ -605,8 +628,8 @@ public class LaFunctions { int p; double[] bData; RDoubleVector b; - if (bin.isMatrix()) { - int[] bDims = getBinDimsNode.getDimensions(bin); + int[] bDims = getBinDimsNode.getDimensions(bin); + if (GetDimAttributeNode.isMatrix(bDims)) { p = bDims[1]; if (p == 0) { throw error(Message.GENERIC, "no right-hand side in 'b'"); @@ -615,7 +638,12 @@ public class LaFunctions { if (p2 != n) { throw error(Message.MUST_BE_SQUARE_COMPATIBLE, "b", p2, p, "a", n, n); } - bData = new double[n * p]; + if (bin.getLength() == n * p) { + bData = bin.materialize().getDataNonShared(); + } else { + bData = new double[n]; + System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p); + } b = RDataFactory.createDoubleVector(bData, RDataFactory.COMPLETE_VECTOR); setBDimsNode.setDimensions(b, new int[]{n, p}); RList binDn = getBinDimNamesNode.getDimNames(bin); @@ -634,24 +662,32 @@ public class LaFunctions { if (bin.getLength() != n) { throw error(Message.MUST_BE_SQUARE_COMPATIBLE, "b", bin.getLength(), p, "a", n, n); } - bData = new double[n]; + if (bin.getLength() == n) { + bData = bin.materialize().getDataNonShared(); + } else { + bData = new double[n]; + System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p); + } b = RDataFactory.createDoubleVector(bData, RDataFactory.COMPLETE_VECTOR); if (aDn != null) { setNamesNode.setNames(b, RDataFactory.createStringVector((String) aDn.getDataAt(1))); } } - System.arraycopy(bin.getInternalStore(), 0, bData, 0, n * p); - int[] ipiv = new int[n]; // work on a copy of A - double[] avals = new double[n * n]; + RDoubleVector aDouble; if (a instanceof RAbstractDoubleVector) { - System.arraycopy(a.getInternalStore(), 0, avals, 0, n * n); + aDouble = ((RAbstractDoubleVector) a).materialize(); + } else { + aDouble = (RDoubleVector) castDouble.doCast(a); + } + double[] avals; + if (aDouble.isShared()) { + avals = aCache.get(aDouble.getLength()); + System.arraycopy(aDouble.getDataWithoutCopying(), 0, avals, 0, n * p); } else { - RDoubleVector aDouble = (RDoubleVector) castDouble.doCast(a); - assert aDouble != a; - avals = aDouble.getInternalStore(); + avals = aDouble.getDataWithoutCopying(); } int info = dgesvNode.execute(n, p, avals, n, ipiv, bData, n); if (info < 0) { -- GitLab