From 0bb19d9f899b6553f2ed4689d5373f7cde626274 Mon Sep 17 00:00:00 2001 From: Florian Angerer <florian.angerer@oracle.com> Date: Wed, 27 Sep 2017 14:32:22 +0200 Subject: [PATCH] Implemented builting 'La_svd'. --- .../r/ffi/impl/llvm/TruffleLLVM_Lapack.java | 18 +++++ .../ffi/impl/managed/Managed_LapackRFFI.java | 5 ++ .../r/ffi/impl/nfi/NativeFunction.java | 1 + .../r/ffi/impl/nfi/TruffleNFI_Lapack.java | 17 ++++ .../fficall/src/truffle_common/lapack_rffi.c | 8 ++ .../r/nodes/builtin/base/BasePackage.java | 1 + .../r/nodes/builtin/base/LaFunctions.java | 81 +++++++++++++++++++ .../truffle/r/runtime/ffi/LapackRFFI.java | 13 +++ .../r/test/builtins/TestBuiltin_svd.java | 25 ++++++ 9 files changed, 169 insertions(+) create mode 100644 com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java index 7addd0fd6e..8b3d76325b 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_Lapack.java @@ -205,6 +205,19 @@ public class TruffleLLVM_Lapack implements LapackRFFI { } } + private static final class TruffleLLVM_DgesddNode extends TruffleLLVM_DownCallNode implements DgesddNode { + + @Override + protected NativeFunction getFunction() { + return NativeFunction.dgesdd; + } + + @Override + public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) { + return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork); + } + } + private static final class TruffleLLVM_DlangeNode extends TruffleLLVM_DownCallNode implements DlangeNode { @Override @@ -295,6 +308,11 @@ public class TruffleLLVM_Lapack implements LapackRFFI { return new TruffleLLVM_DgesvNode(); } + @Override + public DgesddNode createDgesddNode() { + return new TruffleLLVM_DgesddNode(); + } + @Override public DlangeNode createDlangeNode() { return new TruffleLLVM_DlangeNode(); diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java index d6254bbab3..7ed0ec3855 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/managed/Managed_LapackRFFI.java @@ -77,6 +77,11 @@ public class Managed_LapackRFFI implements LapackRFFI { throw unsupported("lapack"); } + @Override + public DgesddNode createDgesddNode() { + throw unsupported("lapack"); + } + @Override public DlangeNode createDlangeNode() { throw unsupported("lapack"); diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java index 44c9c77937..11c56d1c53 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/NativeFunction.java @@ -68,6 +68,7 @@ public enum NativeFunction { dpotri("(uint8, sint32, [double], sint32) : sint32", "call_lapack_"), dpstrf("uint8, sint32, [double], sint32, [sint32], [sint32], double, [double]) : sint32", "call_lapack_"), dgesv("(sint32, sint32, [double], sint32, [sint32], [double], sint32) : sint32", "call_lapack_"), + dgesdd("(uint8, sint32, sint32, [double], sint32, [double], [double], sint32, [double], sint32, [double], sint32, [sint32]) : sint32", "call_lapack_"), dlange("(uint8, sint32, sint32, [double], sint32, [double]) : double", "call_lapack_"), dgecon("(uint8, sint32, [double], sint32, double, [double], [double], [sint32]) : sint32", "call_lapack_"), dsyevr( diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java index a69c62dcb4..790fbab4ae 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Lapack.java @@ -146,6 +146,18 @@ public class TruffleNFI_Lapack implements LapackRFFI { } } + private static class TruffleNFI_DgesddNode extends TruffleNFI_DownCallNode implements DgesddNode { + @Override + protected NativeFunction getFunction() { + return NativeFunction.dgesdd; + } + + @Override + public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) { + return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork); + } + } + private static class TruffleNFI_DlangeNode extends TruffleNFI_DownCallNode implements DlangeNode { @Override protected NativeFunction getFunction() { @@ -233,6 +245,11 @@ public class TruffleNFI_Lapack implements LapackRFFI { return new TruffleNFI_DgesvNode(); } + @Override + public DgesddNode createDgesddNode() { + return new TruffleNFI_DgesddNode(); + } + @Override public DlangeNode createDlangeNode() { return new TruffleNFI_DlangeNode(); diff --git a/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c b/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c index ad968f4e0e..5e7a0e2f16 100644 --- a/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c +++ b/com.oracle.truffle.r.native/fficall/src/truffle_common/lapack_rffi.c @@ -108,6 +108,14 @@ int call_lapack_dgesv(int n, int nrhs, double *a, int lda, int *ipiv, double *b, return info; } +extern int dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldtv, double *work, int *lwork, int *iwork, int *info); + +int call_lapack_dgesdd(char jobz, int m, int n, double *a, int lda, double *s, double *u, int ldu, double *vt, int ldtv, double *work, int lwork, int *iwork) { + int info; + dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldtv, work, &lwork, iwork, &info); + return info; +} + extern double dlange_(char *norm, int *m, int *n, double *a, int *lda, double *work); double call_lapack_dlange(char norm, int m, int n, double *a, int lda, double *work) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 52643afaf9..cd05c3e209 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -569,6 +569,7 @@ public class BasePackage extends RBuiltinPackage { add(LaFunctions.Rs.class, LaFunctionsFactory.RsNodeGen::create); add(LaFunctions.Version.class, LaFunctionsFactory.VersionNodeGen::create); add(LaFunctions.LaSolve.class, LaFunctionsFactory.LaSolveNodeGen::create); + add(LaFunctions.Svd.class, LaFunctionsFactory.SvdNodeGen::create); add(Lapply.class, LapplyNodeGen::create); add(Length.class, LengthNodeGen::create); add(Lengths.class, LengthsNodeGen::create); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java index cca501bce3..0df90a68fc 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java @@ -13,6 +13,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimEq; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimGt; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyDoubleVector; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; @@ -24,6 +25,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.or; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.squareMatrix; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE; @@ -38,6 +40,7 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.attributes.CopyAttributesNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; @@ -62,6 +65,7 @@ import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.GetDataCopy; import com.oracle.truffle.r.runtime.data.nodes.GetReadonlyData; import com.oracle.truffle.r.runtime.ffi.LapackRFFI; import com.oracle.truffle.r.runtime.ffi.RFFIFactory; @@ -730,4 +734,81 @@ public class LaFunctions { return b; } } + + @RBuiltin(name = "La_svd", kind = INTERNAL, parameterNames = {"jobu", "x", "s", "u", "vt"}, behavior = PURE) + public abstract static class Svd extends RBuiltinNode.Arg5 { + + static { + Casts casts = new Casts(Svd.class); + casts.arg("jobu").defaultError(Message.MUST_BE_STRING, "jobu").mustNotBeNull().mustBe(stringValue()).asStringVector().findFirst(); + casts.arg("x").mustNotBeNull().mustBe(doubleValue()).asDoubleVectorClosure(true, true, true); + casts.arg("s").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true); + casts.arg("u").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true); + casts.arg("vt").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true); + } + + @Child private LapackRFFI.DgesddNode dgesddNode = LapackRFFI.DgesddNode.create(); + + @Specialization + protected Object doSvd(String ju, RAbstractDoubleVector x, RAbstractDoubleVector s, RAbstractDoubleVector u, RAbstractDoubleVector vt, + @Cached("create()") GetDataCopy.Double getDataCopyNode, + @Cached("createCopyAllAttributes()") CopyAttributesNode copyAttrNode, + @Cached("create()") GetDimAttributeNode getDimsNode) { + + int[] xdims = getDimsNode.getDimensions(x); + int n = xdims[0]; + int p = xdims[1]; + + int[] udims = getDimsNode.getDimensions(u); + int ldu = udims[0]; + + int[] vtdims = getDimsNode.getDimensions(vt); + int ldvt = vtdims[0]; + + int[] iwork = new int[8 * Math.min(n, p)]; + + RDoubleVector xMaterialized = x.materialize(); + RDoubleVector sMaterialized = s.materialize(); + RDoubleVector uMaterialized = u.materialize(); + RDoubleVector vtMaterialized = vt.materialize(); + + double[] xvals = getDataCopyNode.execute(xMaterialized); + double[] sdata = getDataCopyNode.execute(sMaterialized); + double[] udata = getDataCopyNode.execute(uMaterialized); + double[] vtdata = getDataCopyNode.execute(vtMaterialized); + double[] tmp = new double[1]; + +// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, &tmp, &lwork, +// iwork, &info); + int info = dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, tmp, -1, iwork); + if (info != 0) { + error(Message.LAPACK_ERROR, info, "dgesdd"); + } + + int lwork = (int) tmp[0]; + double[] work = new double[lwork]; +// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, work, &lwork, +// iwork, &info); + dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, work, lwork, iwork); + if (info != 0) { + error(Message.LAPACK_ERROR, info, "dgesdd"); + } + + RStringVector nm = RDataFactory.createStringVector(new String[]{"d", "u", "vt"}, true); + Object[] val = new Object[3]; + RDoubleVector sResult = RDataFactory.createDoubleVector(sdata, false); + RDoubleVector uResult = RDataFactory.createDoubleVector(udata, false); + RDoubleVector vtResult = RDataFactory.createDoubleVector(vtdata, false); + + copyAttrNode.execute(sResult, sResult, sResult.getLength(), sMaterialized, sMaterialized.getLength()); + copyAttrNode.execute(uResult, uResult, uResult.getLength(), uMaterialized, uMaterialized.getLength()); + copyAttrNode.execute(vtResult, vtResult, vtResult.getLength(), vtMaterialized, vtMaterialized.getLength()); + + val[0] = sResult; + val[1] = uResult; + val[2] = vtResult; + + return RDataFactory.createList(val, nm); + } + } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java index d995028b42..92f000c636 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/LapackRFFI.java @@ -144,6 +144,17 @@ public interface LapackRFFI { } } + interface DgesddNode extends NodeInterface { + /** + * See <a href="http://www.netlib.org/lapack/explore-html/db/db4/dgesdd_8f.html">spec</a>. + */ + int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork); + + static DgesddNode create() { + return RFFIFactory.getLapackRFFI().createDgesddNode(); + } + } + interface DlangeNode extends NodeInterface { /** @@ -198,6 +209,8 @@ public interface LapackRFFI { DgesvNode createDgesvNode(); + DgesddNode createDgesddNode(); + DlangeNode createDlangeNode(); DgeconNode createDgeconNode(); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java new file mode 100644 index 0000000000..aaf4adffe9 --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_svd.java @@ -0,0 +1,25 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (c) 2014, Purdue University + * Copyright (c) 2014, 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ +package com.oracle.truffle.r.test.builtins; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +// Checkstyle: stop line length check + +public class TestBuiltin_svd extends TestBase { + + @Test + public void testSvd() { + assertEval(""); + } +} -- GitLab