From 0bb19d9f899b6553f2ed4689d5373f7cde626274 Mon Sep 17 00:00:00 2001
From: Florian Angerer <florian.angerer@oracle.com>
Date: Wed, 27 Sep 2017 14:32:22 +0200
Subject: [PATCH] Implemented builting 'La_svd'.

---
 .../r/ffi/impl/llvm/TruffleLLVM_Lapack.java   | 18 +++++
 .../ffi/impl/managed/Managed_LapackRFFI.java  |  5 ++
 .../r/ffi/impl/nfi/NativeFunction.java        |  1 +
 .../r/ffi/impl/nfi/TruffleNFI_Lapack.java     | 17 ++++
 .../fficall/src/truffle_common/lapack_rffi.c  |  8 ++
 .../r/nodes/builtin/base/BasePackage.java     |  1 +
 .../r/nodes/builtin/base/LaFunctions.java     | 81 +++++++++++++++++++
 .../truffle/r/runtime/ffi/LapackRFFI.java     | 13 +++
 .../r/test/builtins/TestBuiltin_svd.java      | 25 ++++++
 9 files changed, 169 insertions(+)
 create mode 100644 com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java

diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java
index 7addd0fd6e..8b3d76325b 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java
@@ -205,6 +205,19 @@ public class TruffleLLVM_Lapack implements LapackRFFI {
         }
     }
 
+    private static final class TruffleLLVM_DgesddNode extends TruffleLLVM_DownCallNode implements DgesddNode {
+
+        @Override
+        protected NativeFunction getFunction() {
+            return NativeFunction.dgesdd;
+        }
+
+        @Override
+        public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) {
+            return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork);
+        }
+    }
+
     private static final class TruffleLLVM_DlangeNode extends TruffleLLVM_DownCallNode implements DlangeNode {
 
         @Override
@@ -295,6 +308,11 @@ public class TruffleLLVM_Lapack implements LapackRFFI {
         return new TruffleLLVM_DgesvNode();
     }
 
+    @Override
+    public DgesddNode createDgesddNode() {
+        return new TruffleLLVM_DgesddNode();
+    }
+
     @Override
     public DlangeNode createDlangeNode() {
         return new TruffleLLVM_DlangeNode();
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java
index d6254bbab3..7ed0ec3855 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java
@@ -77,6 +77,11 @@ public class Managed_LapackRFFI implements LapackRFFI {
         throw unsupported("lapack");
     }
 
+    @Override
+    public DgesddNode createDgesddNode() {
+        throw unsupported("lapack");
+    }
+
     @Override
     public DlangeNode createDlangeNode() {
         throw unsupported("lapack");
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java
index 44c9c77937..11c56d1c53 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java
@@ -68,6 +68,7 @@ public enum NativeFunction {
     dpotri("(uint8, sint32, [double], sint32) : sint32", "call_lapack_"),
     dpstrf("uint8, sint32, [double], sint32, [sint32], [sint32], double, [double]) : sint32", "call_lapack_"),
     dgesv("(sint32, sint32, [double], sint32, [sint32], [double], sint32) : sint32", "call_lapack_"),
+    dgesdd("(uint8, sint32, sint32, [double], sint32, [double], [double], sint32, [double], sint32, [double], sint32, [sint32]) : sint32", "call_lapack_"),
     dlange("(uint8, sint32, sint32, [double], sint32, [double]) : double", "call_lapack_"),
     dgecon("(uint8, sint32, [double], sint32, double, [double], [double], [sint32]) : sint32", "call_lapack_"),
     dsyevr(
diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java
index a69c62dcb4..790fbab4ae 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java
@@ -146,6 +146,18 @@ public class TruffleNFI_Lapack implements LapackRFFI {
         }
     }
 
+    private static class TruffleNFI_DgesddNode extends TruffleNFI_DownCallNode implements DgesddNode {
+        @Override
+        protected NativeFunction getFunction() {
+            return NativeFunction.dgesdd;
+        }
+
+        @Override
+        public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) {
+            return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork);
+        }
+    }
+
     private static class TruffleNFI_DlangeNode extends TruffleNFI_DownCallNode implements DlangeNode {
         @Override
         protected NativeFunction getFunction() {
@@ -233,6 +245,11 @@ public class TruffleNFI_Lapack implements LapackRFFI {
         return new TruffleNFI_DgesvNode();
     }
 
+    @Override
+    public DgesddNode createDgesddNode() {
+        return new TruffleNFI_DgesddNode();
+    }
+
     @Override
     public DlangeNode createDlangeNode() {
         return new TruffleNFI_DlangeNode();
diff --git a/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c b/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c
index ad968f4e0e..5e7a0e2f16 100644
--- a/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c
+++ b/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c
@@ -108,6 +108,14 @@ int call_lapack_dgesv(int n, int nrhs, double *a, int lda, int *ipiv, double *b,
     return info;
 }
 
+extern int dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldtv, double *work, int *lwork, int *iwork, int *info);
+
+int call_lapack_dgesdd(char jobz, int m, int n, double *a, int lda, double *s, double *u, int ldu, double *vt, int ldtv, double *work, int lwork, int *iwork) {
+    int info;
+    dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldtv, work, &lwork, iwork, &info);
+    return info;
+}
+
 extern double dlange_(char *norm, int *m, int *n, double *a, int *lda, double *work);
 
 double call_lapack_dlange(char norm, int m, int n, double *a, int lda, double *work) {
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
index 52643afaf9..cd05c3e209 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
@@ -569,6 +569,7 @@ public class BasePackage extends RBuiltinPackage {
         add(LaFunctions.Rs.class, LaFunctionsFactory.RsNodeGen::create);
         add(LaFunctions.Version.class, LaFunctionsFactory.VersionNodeGen::create);
         add(LaFunctions.LaSolve.class, LaFunctionsFactory.LaSolveNodeGen::create);
+        add(LaFunctions.Svd.class, LaFunctionsFactory.SvdNodeGen::create);
         add(Lapply.class, LapplyNodeGen::create);
         add(Length.class, LengthNodeGen::create);
         add(Lengths.class, LengthsNodeGen::create);
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
index cca501bce3..0df90a68fc 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java
@@ -13,6 +13,7 @@ package com.oracle.truffle.r.nodes.builtin.base;
 
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimEq;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimGt;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyDoubleVector;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf;
@@ -24,6 +25,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.or;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.squareMatrix;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE;
@@ -38,6 +40,7 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.r.nodes.attributes.CopyAttributesNode;
 import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode;
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
@@ -62,6 +65,7 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
 import com.oracle.truffle.r.runtime.data.RVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
+import com.oracle.truffle.r.runtime.data.nodes.GetDataCopy;
 import com.oracle.truffle.r.runtime.data.nodes.GetReadonlyData;
 import com.oracle.truffle.r.runtime.ffi.LapackRFFI;
 import com.oracle.truffle.r.runtime.ffi.RFFIFactory;
@@ -730,4 +734,81 @@ public class LaFunctions {
             return b;
         }
     }
+
+    @RBuiltin(name = "La_svd", kind = INTERNAL, parameterNames = {"jobu", "x", "s", "u", "vt"}, behavior = PURE)
+    public abstract static class Svd extends RBuiltinNode.Arg5 {
+
+        static {
+            Casts casts = new Casts(Svd.class);
+            casts.arg("jobu").defaultError(Message.MUST_BE_STRING, "jobu").mustNotBeNull().mustBe(stringValue()).asStringVector().findFirst();
+            casts.arg("x").mustNotBeNull().mustBe(doubleValue()).asDoubleVectorClosure(true, true, true);
+            casts.arg("s").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
+            casts.arg("u").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
+            casts.arg("vt").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
+        }
+
+        @Child private LapackRFFI.DgesddNode dgesddNode = LapackRFFI.DgesddNode.create();
+
+        @Specialization
+        protected Object doSvd(String ju, RAbstractDoubleVector x, RAbstractDoubleVector s, RAbstractDoubleVector u, RAbstractDoubleVector vt,
+                        @Cached("create()") GetDataCopy.Double getDataCopyNode,
+                        @Cached("createCopyAllAttributes()") CopyAttributesNode copyAttrNode,
+                        @Cached("create()") GetDimAttributeNode getDimsNode) {
+
+            int[] xdims = getDimsNode.getDimensions(x);
+            int n = xdims[0];
+            int p = xdims[1];
+
+            int[] udims = getDimsNode.getDimensions(u);
+            int ldu = udims[0];
+
+            int[] vtdims = getDimsNode.getDimensions(vt);
+            int ldvt = vtdims[0];
+
+            int[] iwork = new int[8 * Math.min(n, p)];
+
+            RDoubleVector xMaterialized = x.materialize();
+            RDoubleVector sMaterialized = s.materialize();
+            RDoubleVector uMaterialized = u.materialize();
+            RDoubleVector vtMaterialized = vt.materialize();
+
+            double[] xvals = getDataCopyNode.execute(xMaterialized);
+            double[] sdata = getDataCopyNode.execute(sMaterialized);
+            double[] udata = getDataCopyNode.execute(uMaterialized);
+            double[] vtdata = getDataCopyNode.execute(vtMaterialized);
+            double[] tmp = new double[1];
+
+// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, &tmp, &lwork,
+// iwork, &info);
+            int info = dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, tmp, -1, iwork);
+            if (info != 0) {
+                error(Message.LAPACK_ERROR, info, "dgesdd");
+            }
+
+            int lwork = (int) tmp[0];
+            double[] work = new double[lwork];
+// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, work, &lwork,
+// iwork, &info);
+            dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, work, lwork, iwork);
+            if (info != 0) {
+                error(Message.LAPACK_ERROR, info, "dgesdd");
+            }
+
+            RStringVector nm = RDataFactory.createStringVector(new String[]{"d", "u", "vt"}, true);
+            Object[] val = new Object[3];
+            RDoubleVector sResult = RDataFactory.createDoubleVector(sdata, false);
+            RDoubleVector uResult = RDataFactory.createDoubleVector(udata, false);
+            RDoubleVector vtResult = RDataFactory.createDoubleVector(vtdata, false);
+
+            copyAttrNode.execute(sResult, sResult, sResult.getLength(), sMaterialized, sMaterialized.getLength());
+            copyAttrNode.execute(uResult, uResult, uResult.getLength(), uMaterialized, uMaterialized.getLength());
+            copyAttrNode.execute(vtResult, vtResult, vtResult.getLength(), vtMaterialized, vtMaterialized.getLength());
+
+            val[0] = sResult;
+            val[1] = uResult;
+            val[2] = vtResult;
+
+            return RDataFactory.createList(val, nm);
+        }
+    }
 }
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java
index d995028b42..92f000c636 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java
@@ -144,6 +144,17 @@ public interface LapackRFFI {
         }
     }
 
+    interface DgesddNode extends NodeInterface {
+        /**
+         * See <a href="http://www.netlib.org/lapack/explore-html/db/db4/dgesdd_8f.html">spec</a>.
+         */
+        int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork);
+
+        static DgesddNode create() {
+            return RFFIFactory.getLapackRFFI().createDgesddNode();
+        }
+    }
+
     interface DlangeNode extends NodeInterface {
 
         /**
@@ -198,6 +209,8 @@ public interface LapackRFFI {
 
     DgesvNode createDgesvNode();
 
+    DgesddNode createDgesddNode();
+
     DlangeNode createDlangeNode();
 
     DgeconNode createDgeconNode();
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java
new file mode 100644
index 0000000000..aaf4adffe9
--- /dev/null
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java
@@ -0,0 +1,25 @@
+/*
+ * This material is distributed under the GNU General Public License
+ * Version 2. You may review the terms of this license at
+ * http://www.gnu.org/licenses/gpl-2.0.html
+ *
+ * Copyright (c) 2014, Purdue University
+ * Copyright (c) 2014, 2016, Oracle and/or its affiliates
+ *
+ * All rights reserved.
+ */
+package com.oracle.truffle.r.test.builtins;
+
+import org.junit.Test;
+
+import com.oracle.truffle.r.test.TestBase;
+
+// Checkstyle: stop line length check
+
+public class TestBuiltin_svd extends TestBase {
+
+    @Test
+    public void testSvd() {
+        assertEval("");
+    }
+}
-- 
GitLab