Skip to content
Snippets Groups Projects
Commit 0bb19d9f authored by Florian Angerer's avatar Florian Angerer
Browse files

Implemented builting 'La_svd'.

parent f610487f
No related branches found
No related tags found
No related merge requests found
Showing
with 169 additions and 0 deletions
...@@ -205,6 +205,19 @@ public class TruffleLLVM_Lapack implements LapackRFFI { ...@@ -205,6 +205,19 @@ public class TruffleLLVM_Lapack implements LapackRFFI {
} }
} }
private static final class TruffleLLVM_DgesddNode extends TruffleLLVM_DownCallNode implements DgesddNode {
@Override
protected NativeFunction getFunction() {
return NativeFunction.dgesdd;
}
@Override
public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) {
return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork);
}
}
private static final class TruffleLLVM_DlangeNode extends TruffleLLVM_DownCallNode implements DlangeNode { private static final class TruffleLLVM_DlangeNode extends TruffleLLVM_DownCallNode implements DlangeNode {
@Override @Override
...@@ -295,6 +308,11 @@ public class TruffleLLVM_Lapack implements LapackRFFI { ...@@ -295,6 +308,11 @@ public class TruffleLLVM_Lapack implements LapackRFFI {
return new TruffleLLVM_DgesvNode(); return new TruffleLLVM_DgesvNode();
} }
@Override
public DgesddNode createDgesddNode() {
return new TruffleLLVM_DgesddNode();
}
@Override @Override
public DlangeNode createDlangeNode() { public DlangeNode createDlangeNode() {
return new TruffleLLVM_DlangeNode(); return new TruffleLLVM_DlangeNode();
......
...@@ -77,6 +77,11 @@ public class Managed_LapackRFFI implements LapackRFFI { ...@@ -77,6 +77,11 @@ public class Managed_LapackRFFI implements LapackRFFI {
throw unsupported("lapack"); throw unsupported("lapack");
} }
@Override
public DgesddNode createDgesddNode() {
throw unsupported("lapack");
}
@Override @Override
public DlangeNode createDlangeNode() { public DlangeNode createDlangeNode() {
throw unsupported("lapack"); throw unsupported("lapack");
......
...@@ -68,6 +68,7 @@ public enum NativeFunction { ...@@ -68,6 +68,7 @@ public enum NativeFunction {
dpotri("(uint8, sint32, [double], sint32) : sint32", "call_lapack_"), dpotri("(uint8, sint32, [double], sint32) : sint32", "call_lapack_"),
dpstrf("uint8, sint32, [double], sint32, [sint32], [sint32], double, [double]) : sint32", "call_lapack_"), dpstrf("uint8, sint32, [double], sint32, [sint32], [sint32], double, [double]) : sint32", "call_lapack_"),
dgesv("(sint32, sint32, [double], sint32, [sint32], [double], sint32) : sint32", "call_lapack_"), dgesv("(sint32, sint32, [double], sint32, [sint32], [double], sint32) : sint32", "call_lapack_"),
dgesdd("(uint8, sint32, sint32, [double], sint32, [double], [double], sint32, [double], sint32, [double], sint32, [sint32]) : sint32", "call_lapack_"),
dlange("(uint8, sint32, sint32, [double], sint32, [double]) : double", "call_lapack_"), dlange("(uint8, sint32, sint32, [double], sint32, [double]) : double", "call_lapack_"),
dgecon("(uint8, sint32, [double], sint32, double, [double], [double], [sint32]) : sint32", "call_lapack_"), dgecon("(uint8, sint32, [double], sint32, double, [double], [double], [sint32]) : sint32", "call_lapack_"),
dsyevr( dsyevr(
......
...@@ -146,6 +146,18 @@ public class TruffleNFI_Lapack implements LapackRFFI { ...@@ -146,6 +146,18 @@ public class TruffleNFI_Lapack implements LapackRFFI {
} }
} }
private static class TruffleNFI_DgesddNode extends TruffleNFI_DownCallNode implements DgesddNode {
@Override
protected NativeFunction getFunction() {
return NativeFunction.dgesdd;
}
@Override
public int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork) {
return (int) call(jobz, m, n, a, lda, s, u, ldu, vt, ldtv, work, lwork, iwork);
}
}
private static class TruffleNFI_DlangeNode extends TruffleNFI_DownCallNode implements DlangeNode { private static class TruffleNFI_DlangeNode extends TruffleNFI_DownCallNode implements DlangeNode {
@Override @Override
protected NativeFunction getFunction() { protected NativeFunction getFunction() {
...@@ -233,6 +245,11 @@ public class TruffleNFI_Lapack implements LapackRFFI { ...@@ -233,6 +245,11 @@ public class TruffleNFI_Lapack implements LapackRFFI {
return new TruffleNFI_DgesvNode(); return new TruffleNFI_DgesvNode();
} }
@Override
public DgesddNode createDgesddNode() {
return new TruffleNFI_DgesddNode();
}
@Override @Override
public DlangeNode createDlangeNode() { public DlangeNode createDlangeNode() {
return new TruffleNFI_DlangeNode(); return new TruffleNFI_DlangeNode();
......
...@@ -108,6 +108,14 @@ int call_lapack_dgesv(int n, int nrhs, double *a, int lda, int *ipiv, double *b, ...@@ -108,6 +108,14 @@ int call_lapack_dgesv(int n, int nrhs, double *a, int lda, int *ipiv, double *b,
return info; return info;
} }
extern int dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldtv, double *work, int *lwork, int *iwork, int *info);
int call_lapack_dgesdd(char jobz, int m, int n, double *a, int lda, double *s, double *u, int ldu, double *vt, int ldtv, double *work, int lwork, int *iwork) {
int info;
dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldtv, work, &lwork, iwork, &info);
return info;
}
extern double dlange_(char *norm, int *m, int *n, double *a, int *lda, double *work); extern double dlange_(char *norm, int *m, int *n, double *a, int *lda, double *work);
double call_lapack_dlange(char norm, int m, int n, double *a, int lda, double *work) { double call_lapack_dlange(char norm, int m, int n, double *a, int lda, double *work) {
......
...@@ -569,6 +569,7 @@ public class BasePackage extends RBuiltinPackage { ...@@ -569,6 +569,7 @@ public class BasePackage extends RBuiltinPackage {
add(LaFunctions.Rs.class, LaFunctionsFactory.RsNodeGen::create); add(LaFunctions.Rs.class, LaFunctionsFactory.RsNodeGen::create);
add(LaFunctions.Version.class, LaFunctionsFactory.VersionNodeGen::create); add(LaFunctions.Version.class, LaFunctionsFactory.VersionNodeGen::create);
add(LaFunctions.LaSolve.class, LaFunctionsFactory.LaSolveNodeGen::create); add(LaFunctions.LaSolve.class, LaFunctionsFactory.LaSolveNodeGen::create);
add(LaFunctions.Svd.class, LaFunctionsFactory.SvdNodeGen::create);
add(Lapply.class, LapplyNodeGen::create); add(Lapply.class, LapplyNodeGen::create);
add(Length.class, LengthNodeGen::create); add(Length.class, LengthNodeGen::create);
add(Lengths.class, LengthsNodeGen::create); add(Lengths.class, LengthsNodeGen::create);
......
...@@ -13,6 +13,7 @@ package com.oracle.truffle.r.nodes.builtin.base; ...@@ -13,6 +13,7 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimEq; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimEq;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimGt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimGt;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyDoubleVector; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyDoubleVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf;
...@@ -24,6 +25,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; ...@@ -24,6 +25,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.or; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.or;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.squareMatrix; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.squareMatrix;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.READS_STATE;
...@@ -38,6 +40,7 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; ...@@ -38,6 +40,7 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.CopyAttributesNode;
import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
...@@ -62,6 +65,7 @@ import com.oracle.truffle.r.runtime.data.RStringVector; ...@@ -62,6 +65,7 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.GetDataCopy;
import com.oracle.truffle.r.runtime.data.nodes.GetReadonlyData; import com.oracle.truffle.r.runtime.data.nodes.GetReadonlyData;
import com.oracle.truffle.r.runtime.ffi.LapackRFFI; import com.oracle.truffle.r.runtime.ffi.LapackRFFI;
import com.oracle.truffle.r.runtime.ffi.RFFIFactory; import com.oracle.truffle.r.runtime.ffi.RFFIFactory;
...@@ -730,4 +734,81 @@ public class LaFunctions { ...@@ -730,4 +734,81 @@ public class LaFunctions {
return b; return b;
} }
} }
@RBuiltin(name = "La_svd", kind = INTERNAL, parameterNames = {"jobu", "x", "s", "u", "vt"}, behavior = PURE)
public abstract static class Svd extends RBuiltinNode.Arg5 {
static {
Casts casts = new Casts(Svd.class);
casts.arg("jobu").defaultError(Message.MUST_BE_STRING, "jobu").mustNotBeNull().mustBe(stringValue()).asStringVector().findFirst();
casts.arg("x").mustNotBeNull().mustBe(doubleValue()).asDoubleVectorClosure(true, true, true);
casts.arg("s").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
casts.arg("u").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
casts.arg("vt").mustNotBeNull().mustBe(doubleValue()).asDoubleVector(true, true, true);
}
@Child private LapackRFFI.DgesddNode dgesddNode = LapackRFFI.DgesddNode.create();
@Specialization
protected Object doSvd(String ju, RAbstractDoubleVector x, RAbstractDoubleVector s, RAbstractDoubleVector u, RAbstractDoubleVector vt,
@Cached("create()") GetDataCopy.Double getDataCopyNode,
@Cached("createCopyAllAttributes()") CopyAttributesNode copyAttrNode,
@Cached("create()") GetDimAttributeNode getDimsNode) {
int[] xdims = getDimsNode.getDimensions(x);
int n = xdims[0];
int p = xdims[1];
int[] udims = getDimsNode.getDimensions(u);
int ldu = udims[0];
int[] vtdims = getDimsNode.getDimensions(vt);
int ldvt = vtdims[0];
int[] iwork = new int[8 * Math.min(n, p)];
RDoubleVector xMaterialized = x.materialize();
RDoubleVector sMaterialized = s.materialize();
RDoubleVector uMaterialized = u.materialize();
RDoubleVector vtMaterialized = vt.materialize();
double[] xvals = getDataCopyNode.execute(xMaterialized);
double[] sdata = getDataCopyNode.execute(sMaterialized);
double[] udata = getDataCopyNode.execute(uMaterialized);
double[] vtdata = getDataCopyNode.execute(vtMaterialized);
double[] tmp = new double[1];
// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, &tmp, &lwork,
// iwork, &info);
int info = dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, tmp, -1, iwork);
if (info != 0) {
error(Message.LAPACK_ERROR, info, "dgesdd");
}
int lwork = (int) tmp[0];
double[] work = new double[lwork];
// F77_CALL(dgesdd)(ju, &n, &p, xvals, &n, REAL(s), REAL(u), &ldu, REAL(vt), &ldvt, work, &lwork,
// iwork, &info);
dgesddNode.execute(ju.charAt(0), n, p, xvals, n, sdata, udata, ldu, vtdata, ldvt, work, lwork, iwork);
if (info != 0) {
error(Message.LAPACK_ERROR, info, "dgesdd");
}
RStringVector nm = RDataFactory.createStringVector(new String[]{"d", "u", "vt"}, true);
Object[] val = new Object[3];
RDoubleVector sResult = RDataFactory.createDoubleVector(sdata, false);
RDoubleVector uResult = RDataFactory.createDoubleVector(udata, false);
RDoubleVector vtResult = RDataFactory.createDoubleVector(vtdata, false);
copyAttrNode.execute(sResult, sResult, sResult.getLength(), sMaterialized, sMaterialized.getLength());
copyAttrNode.execute(uResult, uResult, uResult.getLength(), uMaterialized, uMaterialized.getLength());
copyAttrNode.execute(vtResult, vtResult, vtResult.getLength(), vtMaterialized, vtMaterialized.getLength());
val[0] = sResult;
val[1] = uResult;
val[2] = vtResult;
return RDataFactory.createList(val, nm);
}
}
} }
...@@ -144,6 +144,17 @@ public interface LapackRFFI { ...@@ -144,6 +144,17 @@ public interface LapackRFFI {
} }
} }
interface DgesddNode extends NodeInterface {
/**
* See <a href="http://www.netlib.org/lapack/explore-html/db/db4/dgesdd_8f.html">spec</a>.
*/
int execute(char jobz, int m, int n, double[] a, int lda, double[] s, double[] u, int ldu, double[] vt, int ldtv, double[] work, int lwork, int[] iwork);
static DgesddNode create() {
return RFFIFactory.getLapackRFFI().createDgesddNode();
}
}
interface DlangeNode extends NodeInterface { interface DlangeNode extends NodeInterface {
/** /**
...@@ -198,6 +209,8 @@ public interface LapackRFFI { ...@@ -198,6 +209,8 @@ public interface LapackRFFI {
DgesvNode createDgesvNode(); DgesvNode createDgesvNode();
DgesddNode createDgesddNode();
DlangeNode createDlangeNode(); DlangeNode createDlangeNode();
DgeconNode createDgeconNode(); DgeconNode createDgeconNode();
......
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (c) 2014, Purdue University
* Copyright (c) 2014, 2016, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.test.builtins;
import org.junit.Test;
import com.oracle.truffle.r.test.TestBase;
// Checkstyle: stop line length check
public class TestBuiltin_svd extends TestBase {
@Test
public void testSvd() {
assertEval("");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment