diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_DownCallNodeFactory.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_DownCallNodeFactory.java index 2852cded38785fd95670e9fa2ccfa8b3c48d6d8b..37dee1e1a4d6d8aaf114f6821a2840f86027090c 100644 --- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_DownCallNodeFactory.java +++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/llvm/TruffleLLVM_DownCallNodeFactory.java @@ -31,6 +31,7 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.ffi.DLL; import com.oracle.truffle.r.runtime.ffi.DLL.DLLInfo; import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle; @@ -70,6 +71,8 @@ final class TruffleLLVM_DownCallNodeFactory extends DownCallNodeFactory { @Override @ExplodeLoop protected long beforeCall(NativeFunction nativeFunction, TruffleObject fn, Object[] args) { + assert !(fn instanceof RFunction); + for (int i = 0; i < args.length; i++) { Object obj = args[i]; if (obj instanceof double[]) { @@ -93,6 +96,8 @@ final class TruffleLLVM_DownCallNodeFactory extends DownCallNodeFactory { @Override @ExplodeLoop protected void afterCall(long before, NativeFunction fn, TruffleObject target, Object[] args) { + assert !(target instanceof RFunction); + for (int i = 0; i < args.length; i++) { Object obj = args[i]; if (obj instanceof NativeArray<?>) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 601f97464f58dc91e5b9ab586645b47b7795b1ac..f4147ee060e9e18ad5a43a21f381ddf23955f4c7 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -122,6 +122,8 @@ import com.oracle.truffle.r.nodes.builtin.fastr.FastRPkgSource; import com.oracle.truffle.r.nodes.builtin.fastr.FastRPkgSourceNodeGen; import com.oracle.truffle.r.nodes.builtin.fastr.FastRRefCountInfo; import com.oracle.truffle.r.nodes.builtin.fastr.FastRRefCountInfoNodeGen; +import com.oracle.truffle.r.nodes.builtin.fastr.FastRRegisterFunctions; +import com.oracle.truffle.r.nodes.builtin.fastr.FastRRegisterFunctionsNodeGen; import com.oracle.truffle.r.nodes.builtin.fastr.FastRSlotAssign; import com.oracle.truffle.r.nodes.builtin.fastr.FastRSlotAssignNodeGen; import com.oracle.truffle.r.nodes.builtin.fastr.FastRSourceInfo; @@ -437,6 +439,7 @@ public class BasePackage extends RBuiltinPackage { add(FastRContext.ChannelSend.class, FastRContextFactory.ChannelSendNodeGen::create); add(FastRContext.Spawn.class, FastRContextFactory.SpawnNodeGen::create); add(FastRContext.Join.class, FastRContextFactory.JoinNodeGen::create); + add(FastRRegisterFunctions.class, FastRRegisterFunctionsNodeGen::create); add(FastrDqrls.class, FastrDqrlsNodeGen::create); add(FastRDebug.class, FastRDebugNodeGen::create); add(FastRSetBreakpoint.class, FastRSetBreakpointNodeGen::create); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/DynLoadFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/DynLoadFunctions.java index 05c9f4f3e4473716fbc0ddd35fec26b398cf8dda..6a5104a15e60b6dfc9fa33e26fb3f0542129f20f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/DynLoadFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/DynLoadFunctions.java @@ -178,7 +178,7 @@ public class DynLoadFunctions { @TruffleBoundary protected Object getSymbolInfo(String symbol, RAbstractStringVector packageName, boolean withReg) { DLL.RegisteredNativeSymbol rns = DLL.RegisteredNativeSymbol.any(); - DLL.SymbolHandle f = findSymbolNode.execute(RRuntime.asString(symbol), packageName.getDataAt(0), rns); + DLL.SymbolHandle f = findSymbolNode.execute(symbol, packageName.getDataAt(0), rns); SymbolInfo symbolInfo = null; if (f != DLL.SYMBOL_NOT_FOUND) { symbolInfo = new SymbolInfo(rns.getDllInfo(), symbol, f); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java index 4c85b9c1f19661da99173f55948e9751b1052939..e5aa670388dc2c8a23d3a5fbfbdca6a93f0adddc 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java @@ -5,7 +5,7 @@ * * Copyright (c) 1995-2012, The R Core Team * Copyright (c) 2003, The R Foundation - * Copyright (c) 2015, 2017, Oracle and/or its affiliates + * Copyright (c) 2015, 2018, Oracle and/or its affiliates * * All rights reserved. */ @@ -23,6 +23,7 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.library.fastrGrid.FastRGridExternalLookup; import com.oracle.truffle.r.library.methods.MethodsListDispatchFactory.R_M_setPrimitiveMethodsNodeGen; import com.oracle.truffle.r.library.methods.MethodsListDispatchFactory.R_externalPtrPrototypeObjectNodeGen; @@ -77,6 +78,7 @@ import com.oracle.truffle.r.library.utils.UnzipNodeGen; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RInternalCodeBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.foreign.CallAndExternalFunctions.DotExternal.CallNamedFunctionNode; +import com.oracle.truffle.r.nodes.function.call.RExplicitCallNode; import com.oracle.truffle.r.nodes.helpers.MaterializeNode; import com.oracle.truffle.r.nodes.objects.GetPrimNameNodeGen; import com.oracle.truffle.r.nodes.objects.NewObjectNodeGen; @@ -91,11 +93,14 @@ import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RExternalPtr; +import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.ffi.CallRFFI; +import com.oracle.truffle.r.runtime.ffi.CallRFFI.InvokeCallNode; import com.oracle.truffle.r.runtime.ffi.DLL; +import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle; import com.oracle.truffle.r.runtime.ffi.NativeCallInfo; import com.oracle.truffle.r.runtime.ffi.RFFIFactory; import com.oracle.truffle.r.runtime.nmath.distr.Cauchy; @@ -229,9 +234,8 @@ public class CallAndExternalFunctions { * could be invoked by a string but experimentally that situation has never been encountered. */ @RBuiltin(name = ".Call", kind = PRIMITIVE, parameterNames = {".NAME", "...", "PACKAGE"}, behavior = COMPLEX) - public abstract static class DotCall extends LookupAdapter { + public abstract static class DotCall extends Dot { - @Child private CallRFFI.InvokeCallNode callRFFINode = RFFIFactory.getCallRFFI().createInvokeCallNode(); @Child private MaterializeNode materializeNode = MaterializeNode.create(true); static { @@ -640,12 +644,17 @@ public class CallAndExternalFunctions { */ @SuppressWarnings("unused") @Specialization(limit = "2", guards = {"cached == symbol", "builtin == null"}) - protected Object callNamedFunction(RList symbol, RArgsValuesAndNames args, Object packageName, + protected Object callNamedFunction(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, Object packageName, @Cached("symbol") RList cached, @Cached("lookupBuiltin(symbol)") RExternalBuiltinNode builtin, @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo, - @Cached("extractSymbolInfo.execute(symbol)") NativeCallInfo nativeCallInfo) { - return callRFFINode.dispatch(nativeCallInfo, materializeArgs(args.getArguments())); + @Cached("extractSymbolInfo.execute(symbol)") NativeCallInfo nativeCallInfo, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { + if (registeredProfile.profile(isRegisteredRFunction(nativeCallInfo))) { + return explicitCall(frame, nativeCallInfo, args); + } else { + return dispatch(nativeCallInfo, materializeArgs(args.getArguments())); + } } /** @@ -653,24 +662,30 @@ public class CallAndExternalFunctions { * such cases there is this generic version. */ @Specialization(replaces = {"callNamedFunction", "doExternal"}) - protected Object callNamedFunctionGeneric(RList symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") Object packageName, - @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo) { + protected Object callNamedFunctionGeneric(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") Object packageName, + @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { RExternalBuiltinNode builtin = lookupBuiltin(symbol); if (builtin != null) { throw RInternalError.shouldNotReachHere("Cache for .Calls with FastR reimplementation (lookupBuiltin(...) != null) exceeded the limit"); } NativeCallInfo nativeCallInfo = extractSymbolInfo.execute(symbol); - return callRFFINode.dispatch(nativeCallInfo, materializeArgs(args.getArguments())); + if (registeredProfile.profile(isRegisteredRFunction(nativeCallInfo))) { + return explicitCall(frame, nativeCallInfo, args); + } else { + return dispatch(nativeCallInfo, materializeArgs(args.getArguments())); + } } /** * {@code .NAME = string}, no package specified. */ @Specialization - protected Object callNamedFunction(String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, + protected Object callNamedFunction(VirtualFrame frame, String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, @Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { - return callNamedFunctionWithPackage(symbol, args, null, rns, findSymbolNode); + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { + return callNamedFunctionWithPackage(frame, symbol, args, null, rns, findSymbolNode, registeredProfile); } /** @@ -678,19 +693,29 @@ public class CallAndExternalFunctions { * define that symbol. */ @Specialization - protected Object callNamedFunctionWithPackage(String symbol, RArgsValuesAndNames args, String packageName, + protected Object callNamedFunctionWithPackage(VirtualFrame frame, String symbol, RArgsValuesAndNames args, String packageName, @Cached("createRegisteredNativeSymbol(CallNST)") DLL.RegisteredNativeSymbol rns, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { DLL.SymbolHandle func = findSymbolNode.execute(symbol, packageName, rns); if (func == DLL.SYMBOL_NOT_FOUND) { throw error(RError.Message.SYMBOL_NOT_IN_TABLE, symbol, "Call", packageName); } - return callRFFINode.dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), materializeArgs(args.getArguments())); + if (registeredProfile.profile(isRegisteredRFunction(func))) { + return explicitCall(frame, func, args); + } else { + return dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), materializeArgs(args.getArguments())); + } } @Specialization - protected Object callNamedFunctionWithPackage(RExternalPtr symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName) { - return callRFFINode.dispatch(new NativeCallInfo("", symbol.getAddr(), null), materializeArgs(args.getArguments())); + protected Object callNamedFunctionWithPackage(VirtualFrame frame, RExternalPtr symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { + if (registeredProfile.profile(isRegisteredRFunction(symbol))) { + return explicitCall(frame, symbol, args); + } else { + return dispatch(new NativeCallInfo("", symbol.getAddr(), null), materializeArgs(args.getArguments())); + } } @SuppressWarnings("unused") @@ -705,9 +730,9 @@ public class CallAndExternalFunctions { * {@link DotCall}. */ @com.oracle.truffle.r.runtime.builtins.RBuiltin(name = ".External", kind = RBuiltinKind.PRIMITIVE, parameterNames = {".NAME", "...", "PACKAGE"}, behavior = RBehavior.COMPLEX) - public abstract static class DotExternal extends LookupAdapter { + public abstract static class DotExternal extends Dot { - @Child private CallRFFI.InvokeCallNode callRFFINode = RFFIFactory.getCallRFFI().createInvokeCallNode(); + @Child private RExplicitCallNode explicitCall; static { Casts.noCasts(DotExternal.class); @@ -774,15 +799,20 @@ public class CallAndExternalFunctions { @SuppressWarnings("unused") @Specialization(limit = "1", guards = {"cached.symbol == symbol"}) - protected Object callNamedFunction(RList symbol, RArgsValuesAndNames args, Object packageName, - @Cached("new(symbol)") CallNamedFunctionNode cached) { - Object list = encodeArgumentPairList(args, cached.nativeCallInfo.name); - return callRFFINode.dispatch(cached.nativeCallInfo, new Object[]{list}); + protected Object callNamedFunction(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, Object packageName, + @Cached("new(symbol)") CallNamedFunctionNode cached, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { + if (registeredProfile.profile(isRegisteredRFunction(cached.nativeCallInfo))) { + return explicitCall(frame, cached.nativeCallInfo, args); + } else { + Object list = encodeArgumentPairList(args, cached.nativeCallInfo.name); + return dispatch(cached.nativeCallInfo, new Object[]{list}); + } } public static class CallNamedFunctionNode extends Node { public final RList symbol; - final NativeCallInfo nativeCallInfo; + public final NativeCallInfo nativeCallInfo; public CallNamedFunctionNode(RList symbol) { this.symbol = symbol; @@ -792,22 +822,28 @@ public class CallAndExternalFunctions { } @Specialization - protected Object callNamedFunction(String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, + protected Object callNamedFunction(VirtualFrame frame, String symbol, RArgsValuesAndNames args, @SuppressWarnings("unused") RMissing packageName, @Cached("createRegisteredNativeSymbol(ExternalNST)") DLL.RegisteredNativeSymbol rns, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { - return callNamedFunctionWithPackage(symbol, args, null, rns, findSymbolNode); + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { + return callNamedFunctionWithPackage(frame, symbol, args, null, rns, findSymbolNode, registeredProfile); } @Specialization - protected Object callNamedFunctionWithPackage(String symbol, RArgsValuesAndNames args, String packageName, + protected Object callNamedFunctionWithPackage(VirtualFrame frame, String symbol, RArgsValuesAndNames args, String packageName, @Cached("createRegisteredNativeSymbol(ExternalNST)") DLL.RegisteredNativeSymbol rns, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { DLL.SymbolHandle func = findSymbolNode.execute(symbol, packageName, rns); if (func == DLL.SYMBOL_NOT_FOUND) { throw error(RError.Message.SYMBOL_NOT_IN_TABLE, symbol, "External", packageName); } - Object list = encodeArgumentPairList(args, symbol); - return callRFFINode.dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), new Object[]{list}); + if (registeredProfile.profile(isRegisteredRFunction(func))) { + return explicitCall(frame, func, args); + } else { + Object list = encodeArgumentPairList(args, symbol); + return dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), new Object[]{list}); + } } @Fallback @@ -817,7 +853,7 @@ public class CallAndExternalFunctions { } @RBuiltin(name = ".External2", visibility = CUSTOM, kind = PRIMITIVE, parameterNames = {".NAME", "...", "PACKAGE"}, behavior = COMPLEX) - public abstract static class DotExternal2 extends LookupAdapter { + public abstract static class DotExternal2 extends Dot { private static final Object CALL = "call"; private static final Object RHO = "rho"; /** @@ -830,8 +866,6 @@ public class CallAndExternalFunctions { */ @CompilationFinal private Object op = null; - @Child private CallRFFI.InvokeCallNode callRFFINode = RFFIFactory.getCallRFFI().createInvokeCallNode(); - static { Casts.noCasts(DotExternal2.class); } @@ -883,7 +917,7 @@ public class CallAndExternalFunctions { protected Object callNamedFunction(RList symbol, RArgsValuesAndNames args, Object packageName, @Cached("new(symbol)") CallNamedFunctionNode cached) { Object list = encodeArgumentPairList(args, cached.nativeCallInfo.name); - return callRFFINode.dispatch(cached.nativeCallInfo, new Object[]{CALL, getOp(), list, RHO}); + return dispatch(cached.nativeCallInfo, new Object[]{CALL, getOp(), list, RHO}); } @Specialization @@ -902,7 +936,7 @@ public class CallAndExternalFunctions { throw error(RError.Message.SYMBOL_NOT_IN_TABLE, symbol, "External2", packageName); } Object list = encodeArgumentPairList(args, symbol); - return callRFFINode.dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), new Object[]{CALL, getOp(), list, RHO}); + return dispatch(new NativeCallInfo(symbol, func, rns.getDllInfo()), new Object[]{CALL, getOp(), list, RHO}); } @Fallback @@ -911,6 +945,45 @@ public class CallAndExternalFunctions { } } + private abstract static class Dot extends LookupAdapter { + @Child private InvokeCallNode callRFFINode = RFFIFactory.getCallRFFI().createInvokeCallNode(); + @Child private RExplicitCallNode explicitCall; + + protected Object dispatch(NativeCallInfo nativeCallInfo, Object[] args) { + return callRFFINode.dispatch(nativeCallInfo, args); + } + + protected Object explicitCall(VirtualFrame frame, NativeCallInfo nativeCallInfo, RArgsValuesAndNames args) { + return explicitCall(frame, nativeCallInfo.address, args); + } + + protected Object explicitCall(VirtualFrame frame, RExternalPtr ptr, RArgsValuesAndNames args) { + return explicitCall(frame, ptr.getAddr(), args); + } + + protected Object explicitCall(VirtualFrame frame, SymbolHandle symbolHandle, RArgsValuesAndNames args) { + if (explicitCall == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + explicitCall = insert(RExplicitCallNode.create()); + } + RFunction function = (RFunction) symbolHandle.asTruffleObject(); + return explicitCall.call(frame, function, args); + } + + protected boolean isRegisteredRFunction(NativeCallInfo nativeCallInfo) { + return isRegisteredRFunction(nativeCallInfo.address); + } + + protected boolean isRegisteredRFunction(RExternalPtr ptr) { + DLL.SymbolHandle addr = ptr.getAddr(); + return !addr.isLong() && addr.asTruffleObject() instanceof RFunction; + } + + protected static boolean isRegisteredRFunction(SymbolHandle handle) { + return !handle.isLong() && handle.asTruffleObject() instanceof RFunction; + } + } + @RBuiltin(name = ".External.graphics", kind = PRIMITIVE, parameterNames = {".NAME", "...", "PACKAGE"}, behavior = COMPLEX) public abstract static class DotExternalGraphics extends LookupAdapter { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FortranAndCFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FortranAndCFunctions.java index 7a3f15189cdaa1292cc8b7f73dbeb4aaadf92087..3c2c1590e9d97bc778e344cb3f2b4727435bef30 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FortranAndCFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FortranAndCFunctions.java @@ -11,6 +11,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base.foreign; +import com.oracle.truffle.api.CompilerDirectives; import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -22,11 +23,13 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.foreign.FortranAndCFunctionsFactory.FortranResultNamesSetterNodeGen; import com.oracle.truffle.r.nodes.builtin.base.foreign.LookupAdapter.ExtractNativeCallInfoNode; +import com.oracle.truffle.r.nodes.function.call.RExplicitCallNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; @@ -34,10 +37,12 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RAttributable; import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.ffi.DLL; +import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle; import com.oracle.truffle.r.runtime.ffi.InvokeCNode; import com.oracle.truffle.r.runtime.ffi.NativeCallInfo; import com.oracle.truffle.r.runtime.ffi.RFFIFactory; @@ -77,6 +82,8 @@ public class FortranAndCFunctions { @Child private FortranResultNamesSetter resNamesSetter = FortranResultNamesSetterNodeGen.create(); + @Child private RExplicitCallNode explicitCall; + @Override @TruffleBoundary public RExternalBuiltinNode lookupBuiltin(RList symbol) { @@ -85,28 +92,51 @@ public class FortranAndCFunctions { @SuppressWarnings("unused") @Specialization(limit = "1", guards = {"cached == symbol", "builtin != null"}) - protected Object doExternal(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, RMissing encoding, @Cached("symbol") RList cached, + protected Object doFortran(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, RMissing encoding, @Cached("symbol") RList cached, @Cached("lookupBuiltin(symbol)") RExternalBuiltinNode builtin) { return resNamesSetter.execute(builtin.call(frame, args), args); } @Specialization(guards = "lookupBuiltin(symbol) == null") - protected RList c(RList symbol, RArgsValuesAndNames args, byte naok, byte dup, @SuppressWarnings("unused") Object rPackage, @SuppressWarnings("unused") RMissing encoding, - @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo) { + protected RList doFortran(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, byte naok, byte dup, @SuppressWarnings("unused") Object rPackage, + @SuppressWarnings("unused") RMissing encoding, + @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { NativeCallInfo nativeCallInfo = extractSymbolInfo.execute(symbol); - return invokeCNode.dispatch(nativeCallInfo, naok, dup, args); + if (registeredProfile.profile(isRegisteredRFunction(nativeCallInfo))) { + if (explicitCall == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + explicitCall = insert(RExplicitCallNode.create()); + } + RFunction function = (RFunction) nativeCallInfo.address.asTruffleObject(); + Object result = explicitCall.call(frame, function, args); + return RDataFactory.createList(new Object[]{result}); + } else { + return invokeCNode.dispatch(nativeCallInfo, naok, dup, args); + } } @Specialization - protected RList c(RAbstractStringVector symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, @SuppressWarnings("unused") RMissing encoding, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { + protected RList doFortran(VirtualFrame frame, RAbstractStringVector symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, @SuppressWarnings("unused") RMissing encoding, + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { String libName = LookupAdapter.checkPackageArg(rPackage); DLL.RegisteredNativeSymbol rns = new DLL.RegisteredNativeSymbol(DLL.NativeSymbolType.Fortran, null, null); DLL.SymbolHandle func = findSymbolNode.execute(symbol.getDataAt(0), libName, rns); if (func == DLL.SYMBOL_NOT_FOUND) { throw error(RError.Message.C_SYMBOL_NOT_IN_TABLE, symbol); } - return invokeCNode.dispatch(new NativeCallInfo(symbol.getDataAt(0), func, rns.getDllInfo()), naok, dup, args); + if (registeredProfile.profile(isRegisteredRFunction(func))) { + if (explicitCall == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + explicitCall = insert(RExplicitCallNode.create()); + } + RFunction function = (RFunction) func.asTruffleObject(); + Object result = explicitCall.call(frame, function, args); + return RDataFactory.createList(new Object[]{result}); + } else { + return invokeCNode.dispatch(new NativeCallInfo(symbol.getDataAt(0), func, rns.getDllInfo()), naok, dup, args); + } } @SuppressWarnings("unused") @@ -114,6 +144,14 @@ public class FortranAndCFunctions { protected Object fallback(Object symbol, Object args, Object naok, Object dup, Object rPackage, Object encoding) { throw LookupAdapter.fallback(this, symbol); } + + protected boolean isRegisteredRFunction(NativeCallInfo nativeCallInfo) { + return isRegisteredRFunction(nativeCallInfo.address); + } + + private static boolean isRegisteredRFunction(SymbolHandle handle) { + return !handle.isLong() && handle.asTruffleObject() instanceof RFunction; + } } public abstract static class FortranResultNamesSetter extends RBaseNode { @@ -161,16 +199,30 @@ public class FortranAndCFunctions { Casts.noCasts(DotC.class); } + @Child private RExplicitCallNode explicitCall; + @Specialization - protected RList c(RList symbol, RArgsValuesAndNames args, byte naok, byte dup, @SuppressWarnings("unused") Object rPackage, @SuppressWarnings("unused") RMissing encoding, - @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo) { + protected RList c(VirtualFrame frame, RList symbol, RArgsValuesAndNames args, byte naok, byte dup, @SuppressWarnings("unused") Object rPackage, @SuppressWarnings("unused") RMissing encoding, + @Cached("new()") ExtractNativeCallInfoNode extractSymbolInfo, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { NativeCallInfo nativeCallInfo = extractSymbolInfo.execute(symbol); - return invokeCNode.dispatch(nativeCallInfo, naok, dup, args); + if (registeredProfile.profile(isRegisteredRFunction(nativeCallInfo))) { + if (explicitCall == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + explicitCall = insert(RExplicitCallNode.create()); + } + RFunction function = (RFunction) nativeCallInfo.address.asTruffleObject(); + Object result = explicitCall.call(frame, function, args); + return RDataFactory.createList(new Object[]{result}); + } else { + return invokeCNode.dispatch(nativeCallInfo, naok, dup, args); + } } @Specialization - protected RList c(RAbstractStringVector symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, @SuppressWarnings("unused") RMissing encoding, - @Cached("create()") DLL.RFindSymbolNode findSymbolNode) { + protected RList c(VirtualFrame frame, RAbstractStringVector symbol, RArgsValuesAndNames args, byte naok, byte dup, Object rPackage, @SuppressWarnings("unused") RMissing encoding, + @Cached("create()") DLL.RFindSymbolNode findSymbolNode, + @Cached("createBinaryProfile()") ConditionProfile registeredProfile) { String libName = null; if (!(rPackage instanceof RMissing)) { libName = RRuntime.asString(rPackage); @@ -183,7 +235,26 @@ public class FortranAndCFunctions { if (func == DLL.SYMBOL_NOT_FOUND) { throw error(RError.Message.C_SYMBOL_NOT_IN_TABLE, symbol); } - return invokeCNode.dispatch(new NativeCallInfo(symbol.getDataAt(0), func, rns.getDllInfo()), naok, dup, args); + if (registeredProfile.profile(isRegisteredRFunction(func))) { + if (explicitCall == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + explicitCall = insert(RExplicitCallNode.create()); + } + + RFunction function = (RFunction) func.asTruffleObject(); + Object result = explicitCall.call(frame, function, args); + return RDataFactory.createList(new Object[]{result}); + } else { + return invokeCNode.dispatch(new NativeCallInfo(symbol.getDataAt(0), func, rns.getDllInfo()), naok, dup, args); + } + } + + protected boolean isRegisteredRFunction(NativeCallInfo nativeCallInfo) { + return isRegisteredRFunction(nativeCallInfo.address); + } + + private static boolean isRegisteredRFunction(SymbolHandle handle) { + return !handle.isLong() && handle.asTruffleObject() instanceof RFunction; } } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/fastr/FastRRegisterFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/fastr/FastRRegisterFunctions.java new file mode 100644 index 0000000000000000000000000000000000000000..307006798e969fd81a747bcf225d3de41047c809 --- /dev/null +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/fastr/FastRRegisterFunctions.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.fastr; + +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.r.nodes.builtin.NodeWithArgumentCasts; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; +import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; +import com.oracle.truffle.r.runtime.context.RContext; +import com.oracle.truffle.r.runtime.data.RFunction; +import com.oracle.truffle.r.runtime.data.RList; +import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.RStringVector; +import com.oracle.truffle.r.runtime.env.REnvironment; +import com.oracle.truffle.r.runtime.env.REnvironment.PutException; +import com.oracle.truffle.r.runtime.ffi.DLL; +import com.oracle.truffle.r.runtime.ffi.DLL.DLLInfo; +import com.oracle.truffle.r.runtime.ffi.DLL.DotSymbol; +import com.oracle.truffle.r.runtime.ffi.DLL.NativeSymbolType; +import com.oracle.truffle.r.runtime.ffi.DLL.RegisteredNativeSymbol; +import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle; +import com.oracle.truffle.r.runtime.ffi.DLL.SymbolInfo; + +/** + * Fake registration of a RFunction as if it was a native function callable by .C, .Call, .External, + * .Fortran. + */ +@RBuiltin(name = ".fastr.register.functions", kind = PRIMITIVE, parameterNames = {"library", "env", "nstOrd", "functions"}, behavior = COMPLEX) +public abstract class FastRRegisterFunctions extends RBuiltinNode.Arg4 { + + static { + NodeWithArgumentCasts.Casts.noCasts(FastRRegisterFunctions.class); + } + + /** + * @param library library under which to register the given function + * @param env the environment to use + * @param nstOrd see {@link NativeSymbolType} , .C - 0, .Call - 1, .Fortran - 2, .External 3 + * @param functions named list of RFunction-s + */ + @Specialization + protected Object register(String library, REnvironment env, int nstOrd, RList functions) { + try { + RStringVector names = functions.getNames(); + + DLLInfo dllInfo = DLL.findLibrary(library); + if (dllInfo == null) { + dllInfo = DLL.createSyntheticLib(RContext.getInstance(), library); + } + + DotSymbol[] symbols = new DotSymbol[names.getLength()]; + SymbolHandle[] symbolHandles = new SymbolHandle[names.getLength()]; + + for (int i = 0; i < names.getLength(); i++) { + assert functions.getDataAt(i) instanceof RFunction : " only RFunction elements are allowed in the functions list: " + functions.getDataAt(i); + RFunction fun = (RFunction) functions.getDataAt(i); + String name = names.getDataAt(i); + assert !name.isEmpty() : "each element in functions list has to be named"; + + symbolHandles[i] = new SymbolHandle(fun); + symbols[i] = new DotSymbol(name, symbolHandles[i], 0); + } + + NativeSymbolType nst = NativeSymbolType.values()[nstOrd]; + DotSymbol[] oldSymbols = dllInfo.getNativeSymbols(nst); + DotSymbol[] newSymbols; + if (oldSymbols == null) { + newSymbols = symbols; + } else { + newSymbols = new DotSymbol[oldSymbols.length + symbols.length]; + System.arraycopy(oldSymbols, 0, newSymbols, 0, oldSymbols.length); + System.arraycopy(symbols, 0, newSymbols, oldSymbols.length, symbols.length); + } + dllInfo.setNativeSymbols(nstOrd, newSymbols); + + assign(symbols, dllInfo, symbolHandles, nst, env); + } catch (REnvironment.PutException ex) { + throw error(ex); + } + return RNull.instance; + } + + @TruffleBoundary + private void assign(DotSymbol[] symbols, DLLInfo dllInfo, SymbolHandle[] symbolHandles, NativeSymbolType nst, REnvironment env) throws PutException { + for (int i = 0; i < symbols.length; i++) { + SymbolInfo si = new SymbolInfo(dllInfo, symbols[i].name, symbolHandles[i]); + RList symbolObject = si.createRSymbolObject(new RegisteredNativeSymbol(nst, symbols[i], dllInfo), true); + env.put(symbols[i].name, symbolObject); + } + } + + @Specialization + protected Object register(String library, REnvironment env, double nstOrd, RList functions) { + return register(library, env, (int) nstOrd, functions); + } + +} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java index 5a2514b0f0fc968a7e03f976a11a4c0cbde76c45..fab1bc5db8463a71a681944a621c5e3dcd52c946 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java @@ -42,6 +42,7 @@ import com.oracle.truffle.r.runtime.data.CharSXPWrapper; import com.oracle.truffle.r.runtime.data.NativeDataAccess.CustomNativeMirror; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RExternalPtr; +import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RObject; @@ -113,7 +114,9 @@ public class DLL { RootCallTarget closeCallTarget = DLCloseRootNode.create(contextArg); for (int i = 1; i < list.size(); i++) { DLLInfo dllInfo = list.get(i); - closeCallTarget.call(dllInfo.handle); + if (!dllInfo.isSynthetic()) { + closeCallTarget.call(dllInfo.handle); + } } } list = null; @@ -209,8 +212,13 @@ public class DLL { private final DotSymbol[][] nativeSymbols = new DotSymbol[NativeSymbolType.values().length][]; private ArrayList<CEntry> cEntryTable = null; private final HashSet<String> unsuccessfulLookups = new HashSet<>(); + /** + * A synthetic DLLInfo faking {@link RFunction}-s as if they were real native symbols to + * .Call etc. + */ + private final boolean syntheticHandle; - private DLLInfo(String name, String path, boolean dynamicLookup, Object handle) { + private DLLInfo(String name, String path, boolean dynamicLookup, Object handle, boolean syntheticHandle) { this.id = ID.getAndIncrement(); this.name = name; this.nameSXP = CharSXPWrapper.create(name); @@ -218,10 +226,15 @@ public class DLL { this.pathSXP = CharSXPWrapper.create(path); this.dynamicLookup = dynamicLookup; this.handle = handle; + this.syntheticHandle = syntheticHandle; } private static DLLInfo create(String name, String path, boolean dynamicLookup, Object handle, boolean addToList) { - DLLInfo result = new DLLInfo(name, path, dynamicLookup, handle); + return create(name, path, dynamicLookup, handle, addToList, false); + } + + private static DLLInfo create(String name, String path, boolean dynamicLookup, Object handle, boolean addToList, boolean syntheticHandle) { + DLLInfo result = new DLLInfo(name, path, dynamicLookup, handle, syntheticHandle); if (addToList) { ContextStateImpl contextState = getContextState(); contextState.list.add(result); @@ -237,6 +250,14 @@ public class DLL { return handle == null; } + /** + * Determines whether this is a synthetic {@link DLLInfo} faking {@link RFunction}-s as if + * they were real native symbols. + */ + private boolean isSynthetic() { + return syntheticHandle; + } + public void setNativeSymbols(int nstOrd, DotSymbol[] symbols) { nativeSymbols[nstOrd] = symbols; } @@ -473,6 +494,12 @@ public class DLL { dllContext.addLibR(DLLInfo.create(libName(path), path, true, handle, false)); } + public static DLLInfo createSyntheticLib(RContext context, String library) { + DLLInfo dllInfo = DLLInfo.create(library, library, true, new Object(), false, true); + context.stateDLL.list.add(dllInfo); + return dllInfo; + } + public static String libName(String absPath) { File file = new File(absPath); String name = file.getName(); @@ -714,6 +741,9 @@ public class DLL { if (f != SYMBOL_NOT_FOUND) { return f; } + if (dllInfo.isSynthetic()) { + return SYMBOL_NOT_FOUND; + } // TODO: there is a weird interaction with namespace environments that makes this not // true in all cases diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index dc4af52603c82f1790d5d045015e208a0eb9c442..4c72153998acf3103c52b7ba99ba17c9eaba3104 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -142999,6 +142999,42 @@ a b c d a b c d e 1 2 3 4 5 +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotC# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 0, list(cfun1=function() {1})); .C(cfun1)==1 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotC# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 0, list(cfun2=function() {2}, cfun3=function() {3}));.fastr.register.functions('testLib', environment(), 0, list(cfun4=function() {4})); .C(cfun2)==2 && .C(cfun3)==3 && .C(cfun4)==4 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotCall# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 1, list(callfun1=function() {1})); .Call(callfun1)==1 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotCall# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 1, list(callfun2=function() {2}, callfun3=function() {3}));.fastr.register.functions('testLib', environment(), 1, list(callfun4=function() {4})); .Call(callfun2)==2 && .Call(callfun3)==3 && .Call(callfun4)==4 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotCall# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 1, list(callfunptr=function() {5})); assign('callptr', getNativeSymbolInfo('callfunptr', 'testLib')$address); .Call(callptr)==5 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotExternal# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 3, list(externalfun1=function() {1})); .External(externalfun1)==1 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotExternal# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 3, list(externalfun2=function() {2}, externalfun3=function() {3}));.fastr.register.functions('testLib', environment(), 3, list(externalfun4=function() {4})); .External(externalfun2)==2 && .External(externalfun3)==3 && .External(externalfun4)==4 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotFortran# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 2, list(ffun5=function() {5})); .Fortran(ffun5)==5 } +[1] TRUE + +##com.oracle.truffle.r.test.library.base.foreign.TestRegisterFunction.testDotFortran# +#if (!any(R.version$engine == "FastR")) { T } else { .fastr.register.functions('testLib', environment(), 2, list(ffun6=function() {6}, ffun7=function() {7}));.fastr.register.functions('testLib', environment(), 2, list(ffun8=function() {8})); .Fortran(ffun6)==6 && .Fortran(ffun7)==7 && .Fortran(ffun8)==8 } +[1] TRUE + ##com.oracle.truffle.r.test.library.fastr.TestChannels.dummyTest# #42 [1] 42 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/foreign/TestRegisterFunction.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/foreign/TestRegisterFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..3fb6b6d02397eb4276e9d275723929ea3d7b2fe6 --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/foreign/TestRegisterFunction.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.test.library.base.foreign; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +// Checkstyle: stop line length check + +public class TestRegisterFunction extends TestBase { + + @Test + public void testDotCall() { + assertEvalFastR(".fastr.register.functions('testLib', environment(), 1, list(callfun1=function() {1})); .Call(callfun1)==1", "T"); + + assertEvalFastR(".fastr.register.functions('testLib', environment(), 1, list(callfun2=function() {2}, callfun3=function() {3}));" + + ".fastr.register.functions('testLib', environment(), 1, list(callfun4=function() {4}));" + " .Call(callfun2)==2 && .Call(callfun3)==3 && .Call(callfun4)==4", "T"); + + assertEvalFastR(".fastr.register.functions('testLib', environment(), 1, list(callfunptr=function() {5})); " + + "assign('callptr', getNativeSymbolInfo('callfunptr', 'testLib')$address); " + + ".Call(callptr)==5", "T"); + } + + @Test + public void testDotExternal() { + assertEvalFastR(".fastr.register.functions('testLib', environment(), 3, list(externalfun1=function() {1})); .External(externalfun1)==1", "T"); + + assertEvalFastR(".fastr.register.functions('testLib', environment(), 3, list(externalfun2=function() {2}, externalfun3=function() {3}));" + + ".fastr.register.functions('testLib', environment(), 3, list(externalfun4=function() {4}));" + + " .External(externalfun2)==2 && .External(externalfun3)==3 && .External(externalfun4)==4", + "T"); + } + + @Test + public void testDotC() { + assertEvalFastR(".fastr.register.functions('testLib', environment(), 0, list(cfun1=function() {1})); .C(cfun1)==1", "T"); + + assertEvalFastR(".fastr.register.functions('testLib', environment(), 0, list(cfun2=function() {2}, cfun3=function() {3}));" + + ".fastr.register.functions('testLib', environment(), 0, list(cfun4=function() {4}));" + " .C(cfun2)==2 && .C(cfun3)==3 && .C(cfun4)==4", + "T"); + } + + @Test + public void testDotFortran() { + assertEvalFastR(".fastr.register.functions('testLib', environment(), 2, list(ffun5=function() {5})); .Fortran(ffun5)==5", "T"); + + assertEvalFastR(".fastr.register.functions('testLib', environment(), 2, list(ffun6=function() {6}, ffun7=function() {7}));" + + ".fastr.register.functions('testLib', environment(), 2, list(ffun8=function() {8}));" + + " .Fortran(ffun6)==6 && .Fortran(ffun7)==7 && .Fortran(ffun8)==8", + "T"); + } +}