diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java index df43310c74523980ba4d90e8568472323c41ebd5..66b3507413d878487fd6c5b3e4c496e0dddc5137 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java @@ -42,6 +42,7 @@ import com.oracle.truffle.r.nodes.function.FunctionExpressionNode; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RFunction; +import com.oracle.truffle.r.runtime.data.RPromise; import com.oracle.truffle.r.runtime.env.REnvironment; import qir.ast.*; @@ -98,17 +99,20 @@ public final class QIRInterface { * Resolves the free variables of a query. * * @param arg The query to be executed in the targeted database - * @param frame The environment of execution + * @param argFrame The environment of execution * @return The closed query */ - public static final QIRNode normalize(final QIRNode arg, final Frame frame) { + public static final QIRNode normalize(final QIRNode arg, final Frame argFrame) { QIRNode query = arg; final SourceSection dummy = Source.newBuilder("").name("QIR interface").mimeType(RRuntime.R_APP_MIME).build().createUnavailableSection(); for (Map<String, QIRVariable> fvs = QIRDriver.getFreeVars(query); !fvs.isEmpty(); fvs = QIRDriver.getFreeVars(query)) for (final QIRVariable fv : fvs.values()) { final String varName = fv.id; - final FrameSlot varSlot = frame.getFrameDescriptor().findFrameSlot(varName); + Frame frame = argFrame; + FrameSlot varSlot = frame.getFrameDescriptor().findFrameSlot(varName); + for (; varSlot == null && frame.getArguments()[4] instanceof Frame; frame = (Frame) frame.getArguments()[4]) + varSlot = ((Frame) frame.getArguments()[4]).getFrameDescriptor().findFrameSlot(varName); query = new QIRApply(dummy, new QIRLambda(dummy, null, new QIRVariable(dummy, varName, varSlot), query, new FrameDescriptor()), RToQIRType(fv.sourceSection, varSlot != null ? frame.getValue(varSlot) : RContext.lookupBuiltin(varName))); } @@ -234,26 +238,38 @@ public final class QIRInterface { } if (value instanceof RFunction) { final RFunction fun = (RFunction) value; - if (fun.isBuiltin()) - switch (fun.getName()) { - case "new.env": - return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.instance, new FrameDescriptor()); - case "return": - case "(": { - final QIRVariable x = new QIRVariable(null, "x", null); - return new QIRLambda(src, "identity", x, x, new FrameDescriptor()); - } - case "c": { - // TODO: This works only for lists with one element - final QIRVariable x = new QIRVariable(null, "x", null); - return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.instance), new FrameDescriptor()); - } - case "sum": - return new QIRBuiltin(src, "sum"); - default: - throw new RuntimeException("Unsupported value: " + value + " : " + value.getClass()); + switch (fun.getName()) { + case "new.env": + return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.instance, new FrameDescriptor()); + case "return": + case "(": + case "query.force": { + final QIRVariable x = new QIRVariable(null, "x", null); + return new QIRLambda(src, "identity", x, x, new FrameDescriptor()); } - return RFunctionToQIRType(src, fun.getName(), (FunctionDefinitionNode) fun.getRootNode()); + case "c": { + // TODO: This works only for lists with one element + final QIRVariable x = new QIRVariable(null, "x", null); + return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.instance), new FrameDescriptor()); + } + case "sum": + return new QIRBuiltin(src, "sum"); + case "new.tableRef": + final QIRVariable tableName = new QIRVariable(null, "__tmp__"); + final QIRVariable schemaName = new QIRVariable(null, "__tmp2__"); + final QIRVariable dummyVar = new QIRVariable(null, "__dummy__"); + return new QIRLambda(null, null, tableName, + new QIRLambda(null, null, dummyVar, new QIRLambda(null, null, dummyVar, + new QIRLambda(null, null, schemaName, new QIRTable<>(null, tableName, schemaName, null), new FrameDescriptor()), new FrameDescriptor()), + new FrameDescriptor()), + new FrameDescriptor()); + default: + return RFunctionToQIRType(src, fun.getName(), (FunctionDefinitionNode) fun.getRootNode()); + } + } + if (value instanceof RPromise) { + final RPromise fun = (RPromise) value; + return RToQIRType(src, fun.getValue()); } throw new RuntimeException("Unsupported value: " + value); } @@ -265,7 +281,7 @@ public final class QIRInterface { if (args.length == 0) return new QIRLambda(src, funName, null, res, new FrameDescriptor()); - for (int i = 0; i < args.length; i++) + for (int i = args.length - 1; i >= 0; i--) res = new QIRLambda(src, funName, new QIRVariable(null, args[i], null), res, new FrameDescriptor()); return res; } catch (UnsupportedOperationException e) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java index 3f65ed370b7e1c6ade125bfdecfffffa50bb4d3d..16a0667b4c0d5e9a43532dc602dcd8adc02adf26 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java @@ -26,6 +26,7 @@ import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; import qir.ast.*; import qir.ast.data.*; import qir.ast.expression.QIRNull; +import qir.ast.expression.QIRString; import qir.ast.expression.arithmetic.*; import qir.ast.expression.logic.*; import qir.ast.expression.relational.*; @@ -134,6 +135,12 @@ public final class QIRTranslateVisitor implements RSyntaxNodeVisitor<QIRNode> { @Override public final QIRNode visit(final RCallNode call) { + // TODO: Remove this hack + try { + return visitBuiltin(call.getSourceSection(), ((ReadVariableNode) call.getFunction()).getIdentifier(), + Arrays.stream(call.getArguments().getArguments()).map(arg -> arg.accept(this)).collect(Collectors.toList())); + } catch (final RuntimeException e) { + } final QIRNode fun = call.getFunction().asRSyntaxNode().accept(this); final RSyntaxNode[] args = call.getArguments().getArguments(); final int nbArgs = args.length; @@ -148,37 +155,39 @@ public final class QIRTranslateVisitor implements RSyntaxNodeVisitor<QIRNode> { @Override public final QIRNode visit(final RCallSpecialNode dot) { - final String name = dot.expectedFunction.getName(); - final List<QIRNode> nodes = Arrays.stream(dot.getSyntaxArguments()).map(arg -> ((RSyntaxNode) arg).accept(this)).collect(Collectors.toList()); + return visitBuiltin(dot.getSourceSection(), dot.expectedFunction.getName(), Arrays.stream(dot.getSyntaxArguments()).map(arg -> ((RSyntaxNode) arg).accept(this)).collect(Collectors.toList())); + } + + private static final QIRNode visitBuiltin(final SourceSection src, final String name, final List<QIRNode> args) { switch (name) { case "$": - return new QIRTdestr(dot.getSourceSection(), nodes.get(0), (String) ((ConstantNode) dot.getSyntaxArguments()[1]).getValue()); + return new QIRTdestr(src, args.get(0), ((QIRString) args.get(1)).value); case "+": - return QIRPlusNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRPlusNodeGen.create(src, args.get(0), args.get(1)); case "-": - return QIRMinusNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRMinusNodeGen.create(src, args.get(0), args.get(1)); case "*": - return QIRStarNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRStarNodeGen.create(src, args.get(0), args.get(1)); case "/": - return QIRDivNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRDivNodeGen.create(src, args.get(0), args.get(1)); case "==": - return QIREqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIREqualNodeGen.create(src, args.get(0), args.get(1)); case "&": case "&&": - return QIRAndNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRAndNodeGen.create(src, args.get(0), args.get(1)); case "|": case "||": - return QIROrNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIROrNodeGen.create(src, args.get(0), args.get(1)); case "<=": - return QIRLowerOrEqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRLowerOrEqualNodeGen.create(src, args.get(0), args.get(1)); case "<": - return QIRLowerThanNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRLowerThanNodeGen.create(src, args.get(0), args.get(1)); case ">=": - return QIRNotNodeGen.create(dot.getSourceSection(), QIRLowerThanNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1))); + return QIRNotNodeGen.create(src, QIRLowerThanNodeGen.create(src, args.get(0), args.get(1))); case ">": - return QIRNotNodeGen.create(dot.getSourceSection(), QIRLowerOrEqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1))); + return QIRNotNodeGen.create(src, QIRLowerOrEqualNodeGen.create(src, args.get(0), args.get(1))); case "!": - return QIRNotNodeGen.create(dot.getSourceSection(), nodes.get(0)); + return QIRNotNodeGen.create(src, args.get(0)); default: throw new RuntimeException("Unknown call special node: " + name); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java index bb96aeafc0b8c75d4ce0fb09dd4a5ee85b12308c..5644e696343ea5ff93c51aa75161ddfcbf5ff9b5 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java @@ -26,14 +26,17 @@ import com.oracle.truffle.api.CompilerDirectives.*; import com.oracle.truffle.api.frame.*; import com.oracle.truffle.api.nodes.*; import com.oracle.truffle.api.source.SourceSection; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.env.REnvironment.PutException; import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; +import com.oracle.truffle.r.runtime.nodes.RSyntaxFunction; @NodeInfo(shortName = "query", description = "The node representing a query") -public final class RQIRWrapperNode extends RSourceSectionNode { +public final class RQIRWrapperNode extends RSourceSectionNode implements RSyntaxFunction { // The unique identifier of the query public final int id; @@ -60,4 +63,24 @@ public final class RQIRWrapperNode extends RSourceSectionNode { } return res; } + + @Override + public ArgumentsSignature getSyntaxSignature() { + return null; + } + + @Override + public RSyntaxElement[] getSyntaxArgumentDefaults() { + return null; + } + + @Override + public RSyntaxElement getSyntaxBody() { + return null; + } + + @Override + public String getSyntaxDebugName() { + return null; + } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java index c3324c32837b8778b838282b4058af4758f24d30..fdcc9ab41b1abd3ce04cd6e09d6a1b94a13305f7 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java @@ -26,11 +26,14 @@ import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.NodeInfo; import com.oracle.truffle.api.object.DynamicObject; import com.oracle.truffle.api.source.SourceSection; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; +import com.oracle.truffle.r.runtime.nodes.RSyntaxFunction; import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; @NodeInfo(shortName = "Query") -public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNode { +public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNode, RSyntaxFunction { public RQueryNode(final SourceSection src) { super(src); } @@ -39,4 +42,24 @@ public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNo public final DynamicObject execute(final VirtualFrame frame) { throw new RuntimeException("We should not execute a RQueryNode directly."); } + + @Override + public ArgumentsSignature getSyntaxSignature() { + return null; + } + + @Override + public RSyntaxElement[] getSyntaxArgumentDefaults() { + return null; + } + + @Override + public RSyntaxElement getSyntaxBody() { + return null; + } + + @Override + public String getSyntaxDebugName() { + return null; + } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java index 1eb02abccfe7c295dd4749a7c44cbd72ccf2b28a..cfe15724b76fe74c28bf3856a902f0cd0e6eb56a 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java @@ -66,7 +66,7 @@ public final class RQueryVisitor implements RSyntaxNodeVisitor<RSyntaxNode> { @Override public final RSyntaxNode visit(final FunctionExpressionNode fun) { - return FunctionExpressionNode.create(fun.getSourceSection(), fun.getCallTarget()); + return fun; } @Override diff --git a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g index 81db5e22707262ebf2481653a7d0237039483d6f..ab20ce2a422a7f45914ec0230a19ec3f3f77d21d 100644 --- a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g +++ b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g @@ -186,7 +186,7 @@ root_function [String name] returns [RootCallTarget v] throw RInternalError.shouldNotReachHere("not at EOF after parsing deserialized function"); } } - : n_ op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.rootFunction(src($op, last()), params, $body.v, name); } + : n_ op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.rootFunction(src($op, last()), params, builder.handleQueries($body.v), name); } ; statement returns [T v] @@ -264,7 +264,7 @@ repeat_expr returns [T v] function [T assignedTo] returns [T v] @init { List<Argument<T>> params = new ArrayList<>(); } - : op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.function(src($op, last()), params, $body.v, assignedTo); } + : op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.function(src($op, last()), params, builder.handleQueries($body.v), assignedTo); } ; par_decl [List<Argument<T>> l] diff --git a/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R new file mode 100644 index 0000000000000000000000000000000000000000..27e8b7aaaf37ca9bbb58728102821838023c81ab --- /dev/null +++ b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R @@ -0,0 +1,25 @@ +# Returns the exchange rate between rfrom and rto +getRate = function(rfrom, rto) +{ + change = new.tableRef("change", "PostgreSQL", "postgre.config", "public") + rate = query.force(where(function (r) r$cfrom == rfrom && r$cto == rto, + from(change))) + if (rfrom == rto) 1 else rate$change +} + +# Returns the names of employees earning at least minSalary in the curr +# currency +atLeast = function(minSalary, curr) +{ + emp = new.tableRef("emp", "PostgreSQL", "postgre.config", "public") + select(function (e) { r = new.env() + r$name = e$ename + r }, + where(function (e) e$sal >= minSalary * getRate("USD", curr), + from(emp))) +} + +richUSPeople = atLeast(2000, "USD") +richEURPeople = atLeast(2000, "EUR") +print(query.force(richUSPeople)) +print(query.force(richEURPeople)) diff --git a/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out new file mode 100644 index 0000000000000000000000000000000000000000..53da7150c82f6978dfcb1dec73b98d72aa03d3cf --- /dev/null +++ b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out @@ -0,0 +1,18 @@ + name +1 SMITH +2 ALLEN +3 WARD +4 JONES +5 SCOTT +6 ADAMS +7 MILLER + name +1 SMITH +2 ALLEN +3 WARD +4 JONES +5 SCOTT +6 ADAMS +7 FORD +8 MILLER +