From 1eb0cc1abce4435f069727d9356cb14520862f94 Mon Sep 17 00:00:00 2001 From: Julien Lopez <julien.lopez@lri.fr> Date: Thu, 23 Feb 2017 18:47:20 +0100 Subject: [PATCH] Fixes in translation --- .../r/nodes/qirinterface/QIRInterface.java | 62 ++++++++++++------- .../qirinterface/QIRTranslateVisitor.java | 39 +++++++----- .../r/nodes/query/RQIRWrapperNode.java | 25 +++++++- .../truffle/r/nodes/query/RQueryNode.java | 25 +++++++- .../truffle/r/nodes/query/RQueryVisitor.java | 2 +- .../src/com/oracle/truffle/r/parser/R.g | 4 +- .../tests/pgsql/QueryNested.R | 25 ++++++++ .../tests/pgsql/QueryNested.out | 18 ++++++ 8 files changed, 157 insertions(+), 43 deletions(-) create mode 100644 com.oracle.truffle.r.test/tests/pgsql/QueryNested.R create mode 100644 com.oracle.truffle.r.test/tests/pgsql/QueryNested.out diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java index df43310c74..66b3507413 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java @@ -42,6 +42,7 @@ import com.oracle.truffle.r.nodes.function.FunctionExpressionNode; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RFunction; +import com.oracle.truffle.r.runtime.data.RPromise; import com.oracle.truffle.r.runtime.env.REnvironment; import qir.ast.*; @@ -98,17 +99,20 @@ public final class QIRInterface { * Resolves the free variables of a query. * * @param arg The query to be executed in the targeted database - * @param frame The environment of execution + * @param argFrame The environment of execution * @return The closed query */ - public static final QIRNode normalize(final QIRNode arg, final Frame frame) { + public static final QIRNode normalize(final QIRNode arg, final Frame argFrame) { QIRNode query = arg; final SourceSection dummy = Source.newBuilder("").name("QIR interface").mimeType(RRuntime.R_APP_MIME).build().createUnavailableSection(); for (Map<String, QIRVariable> fvs = QIRDriver.getFreeVars(query); !fvs.isEmpty(); fvs = QIRDriver.getFreeVars(query)) for (final QIRVariable fv : fvs.values()) { final String varName = fv.id; - final FrameSlot varSlot = frame.getFrameDescriptor().findFrameSlot(varName); + Frame frame = argFrame; + FrameSlot varSlot = frame.getFrameDescriptor().findFrameSlot(varName); + for (; varSlot == null && frame.getArguments()[4] instanceof Frame; frame = (Frame) frame.getArguments()[4]) + varSlot = ((Frame) frame.getArguments()[4]).getFrameDescriptor().findFrameSlot(varName); query = new QIRApply(dummy, new QIRLambda(dummy, null, new QIRVariable(dummy, varName, varSlot), query, new FrameDescriptor()), RToQIRType(fv.sourceSection, varSlot != null ? frame.getValue(varSlot) : RContext.lookupBuiltin(varName))); } @@ -234,26 +238,38 @@ public final class QIRInterface { } if (value instanceof RFunction) { final RFunction fun = (RFunction) value; - if (fun.isBuiltin()) - switch (fun.getName()) { - case "new.env": - return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.instance, new FrameDescriptor()); - case "return": - case "(": { - final QIRVariable x = new QIRVariable(null, "x", null); - return new QIRLambda(src, "identity", x, x, new FrameDescriptor()); - } - case "c": { - // TODO: This works only for lists with one element - final QIRVariable x = new QIRVariable(null, "x", null); - return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.instance), new FrameDescriptor()); - } - case "sum": - return new QIRBuiltin(src, "sum"); - default: - throw new RuntimeException("Unsupported value: " + value + " : " + value.getClass()); + switch (fun.getName()) { + case "new.env": + return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.instance, new FrameDescriptor()); + case "return": + case "(": + case "query.force": { + final QIRVariable x = new QIRVariable(null, "x", null); + return new QIRLambda(src, "identity", x, x, new FrameDescriptor()); } - return RFunctionToQIRType(src, fun.getName(), (FunctionDefinitionNode) fun.getRootNode()); + case "c": { + // TODO: This works only for lists with one element + final QIRVariable x = new QIRVariable(null, "x", null); + return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.instance), new FrameDescriptor()); + } + case "sum": + return new QIRBuiltin(src, "sum"); + case "new.tableRef": + final QIRVariable tableName = new QIRVariable(null, "__tmp__"); + final QIRVariable schemaName = new QIRVariable(null, "__tmp2__"); + final QIRVariable dummyVar = new QIRVariable(null, "__dummy__"); + return new QIRLambda(null, null, tableName, + new QIRLambda(null, null, dummyVar, new QIRLambda(null, null, dummyVar, + new QIRLambda(null, null, schemaName, new QIRTable<>(null, tableName, schemaName, null), new FrameDescriptor()), new FrameDescriptor()), + new FrameDescriptor()), + new FrameDescriptor()); + default: + return RFunctionToQIRType(src, fun.getName(), (FunctionDefinitionNode) fun.getRootNode()); + } + } + if (value instanceof RPromise) { + final RPromise fun = (RPromise) value; + return RToQIRType(src, fun.getValue()); } throw new RuntimeException("Unsupported value: " + value); } @@ -265,7 +281,7 @@ public final class QIRInterface { if (args.length == 0) return new QIRLambda(src, funName, null, res, new FrameDescriptor()); - for (int i = 0; i < args.length; i++) + for (int i = args.length - 1; i >= 0; i--) res = new QIRLambda(src, funName, new QIRVariable(null, args[i], null), res, new FrameDescriptor()); return res; } catch (UnsupportedOperationException e) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java index 3f65ed370b..16a0667b4c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRTranslateVisitor.java @@ -26,6 +26,7 @@ import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; import qir.ast.*; import qir.ast.data.*; import qir.ast.expression.QIRNull; +import qir.ast.expression.QIRString; import qir.ast.expression.arithmetic.*; import qir.ast.expression.logic.*; import qir.ast.expression.relational.*; @@ -134,6 +135,12 @@ public final class QIRTranslateVisitor implements RSyntaxNodeVisitor<QIRNode> { @Override public final QIRNode visit(final RCallNode call) { + // TODO: Remove this hack + try { + return visitBuiltin(call.getSourceSection(), ((ReadVariableNode) call.getFunction()).getIdentifier(), + Arrays.stream(call.getArguments().getArguments()).map(arg -> arg.accept(this)).collect(Collectors.toList())); + } catch (final RuntimeException e) { + } final QIRNode fun = call.getFunction().asRSyntaxNode().accept(this); final RSyntaxNode[] args = call.getArguments().getArguments(); final int nbArgs = args.length; @@ -148,37 +155,39 @@ public final class QIRTranslateVisitor implements RSyntaxNodeVisitor<QIRNode> { @Override public final QIRNode visit(final RCallSpecialNode dot) { - final String name = dot.expectedFunction.getName(); - final List<QIRNode> nodes = Arrays.stream(dot.getSyntaxArguments()).map(arg -> ((RSyntaxNode) arg).accept(this)).collect(Collectors.toList()); + return visitBuiltin(dot.getSourceSection(), dot.expectedFunction.getName(), Arrays.stream(dot.getSyntaxArguments()).map(arg -> ((RSyntaxNode) arg).accept(this)).collect(Collectors.toList())); + } + + private static final QIRNode visitBuiltin(final SourceSection src, final String name, final List<QIRNode> args) { switch (name) { case "$": - return new QIRTdestr(dot.getSourceSection(), nodes.get(0), (String) ((ConstantNode) dot.getSyntaxArguments()[1]).getValue()); + return new QIRTdestr(src, args.get(0), ((QIRString) args.get(1)).value); case "+": - return QIRPlusNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRPlusNodeGen.create(src, args.get(0), args.get(1)); case "-": - return QIRMinusNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRMinusNodeGen.create(src, args.get(0), args.get(1)); case "*": - return QIRStarNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRStarNodeGen.create(src, args.get(0), args.get(1)); case "/": - return QIRDivNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRDivNodeGen.create(src, args.get(0), args.get(1)); case "==": - return QIREqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIREqualNodeGen.create(src, args.get(0), args.get(1)); case "&": case "&&": - return QIRAndNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRAndNodeGen.create(src, args.get(0), args.get(1)); case "|": case "||": - return QIROrNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIROrNodeGen.create(src, args.get(0), args.get(1)); case "<=": - return QIRLowerOrEqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRLowerOrEqualNodeGen.create(src, args.get(0), args.get(1)); case "<": - return QIRLowerThanNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1)); + return QIRLowerThanNodeGen.create(src, args.get(0), args.get(1)); case ">=": - return QIRNotNodeGen.create(dot.getSourceSection(), QIRLowerThanNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1))); + return QIRNotNodeGen.create(src, QIRLowerThanNodeGen.create(src, args.get(0), args.get(1))); case ">": - return QIRNotNodeGen.create(dot.getSourceSection(), QIRLowerOrEqualNodeGen.create(dot.getSourceSection(), nodes.get(0), nodes.get(1))); + return QIRNotNodeGen.create(src, QIRLowerOrEqualNodeGen.create(src, args.get(0), args.get(1))); case "!": - return QIRNotNodeGen.create(dot.getSourceSection(), nodes.get(0)); + return QIRNotNodeGen.create(src, args.get(0)); default: throw new RuntimeException("Unknown call special node: " + name); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java index bb96aeafc0..5644e69634 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQIRWrapperNode.java @@ -26,14 +26,17 @@ import com.oracle.truffle.api.CompilerDirectives.*; import com.oracle.truffle.api.frame.*; import com.oracle.truffle.api.nodes.*; import com.oracle.truffle.api.source.SourceSection; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.env.REnvironment.PutException; import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; +import com.oracle.truffle.r.runtime.nodes.RSyntaxFunction; @NodeInfo(shortName = "query", description = "The node representing a query") -public final class RQIRWrapperNode extends RSourceSectionNode { +public final class RQIRWrapperNode extends RSourceSectionNode implements RSyntaxFunction { // The unique identifier of the query public final int id; @@ -60,4 +63,24 @@ public final class RQIRWrapperNode extends RSourceSectionNode { } return res; } + + @Override + public ArgumentsSignature getSyntaxSignature() { + return null; + } + + @Override + public RSyntaxElement[] getSyntaxArgumentDefaults() { + return null; + } + + @Override + public RSyntaxElement getSyntaxBody() { + return null; + } + + @Override + public String getSyntaxDebugName() { + return null; + } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java index c3324c3283..fdcc9ab41b 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryNode.java @@ -26,11 +26,14 @@ import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.NodeInfo; import com.oracle.truffle.api.object.DynamicObject; import com.oracle.truffle.api.source.SourceSection; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; +import com.oracle.truffle.r.runtime.nodes.RSyntaxFunction; import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; @NodeInfo(shortName = "Query") -public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNode { +public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNode, RSyntaxFunction { public RQueryNode(final SourceSection src) { super(src); } @@ -39,4 +42,24 @@ public abstract class RQueryNode extends RSourceSectionNode implements RSyntaxNo public final DynamicObject execute(final VirtualFrame frame) { throw new RuntimeException("We should not execute a RQueryNode directly."); } + + @Override + public ArgumentsSignature getSyntaxSignature() { + return null; + } + + @Override + public RSyntaxElement[] getSyntaxArgumentDefaults() { + return null; + } + + @Override + public RSyntaxElement getSyntaxBody() { + return null; + } + + @Override + public String getSyntaxDebugName() { + return null; + } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java index 1eb02abccf..cfe15724b7 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/query/RQueryVisitor.java @@ -66,7 +66,7 @@ public final class RQueryVisitor implements RSyntaxNodeVisitor<RSyntaxNode> { @Override public final RSyntaxNode visit(final FunctionExpressionNode fun) { - return FunctionExpressionNode.create(fun.getSourceSection(), fun.getCallTarget()); + return fun; } @Override diff --git a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g index 81db5e2270..ab20ce2a42 100644 --- a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g +++ b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g @@ -186,7 +186,7 @@ root_function [String name] returns [RootCallTarget v] throw RInternalError.shouldNotReachHere("not at EOF after parsing deserialized function"); } } - : n_ op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.rootFunction(src($op, last()), params, $body.v, name); } + : n_ op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.rootFunction(src($op, last()), params, builder.handleQueries($body.v), name); } ; statement returns [T v] @@ -264,7 +264,7 @@ repeat_expr returns [T v] function [T assignedTo] returns [T v] @init { List<Argument<T>> params = new ArrayList<>(); } - : op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.function(src($op, last()), params, $body.v, assignedTo); } + : op=FUNCTION n_ LPAR n_ (par_decl[params] (n_ COMMA n_ par_decl[params])* n_)? RPAR n_ body=expr_or_assign { $v = builder.function(src($op, last()), params, builder.handleQueries($body.v), assignedTo); } ; par_decl [List<Argument<T>> l] diff --git a/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R new file mode 100644 index 0000000000..27e8b7aaaf --- /dev/null +++ b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.R @@ -0,0 +1,25 @@ +# Returns the exchange rate between rfrom and rto +getRate = function(rfrom, rto) +{ + change = new.tableRef("change", "PostgreSQL", "postgre.config", "public") + rate = query.force(where(function (r) r$cfrom == rfrom && r$cto == rto, + from(change))) + if (rfrom == rto) 1 else rate$change +} + +# Returns the names of employees earning at least minSalary in the curr +# currency +atLeast = function(minSalary, curr) +{ + emp = new.tableRef("emp", "PostgreSQL", "postgre.config", "public") + select(function (e) { r = new.env() + r$name = e$ename + r }, + where(function (e) e$sal >= minSalary * getRate("USD", curr), + from(emp))) +} + +richUSPeople = atLeast(2000, "USD") +richEURPeople = atLeast(2000, "EUR") +print(query.force(richUSPeople)) +print(query.force(richEURPeople)) diff --git a/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out new file mode 100644 index 0000000000..53da7150c8 --- /dev/null +++ b/com.oracle.truffle.r.test/tests/pgsql/QueryNested.out @@ -0,0 +1,18 @@ + name +1 SMITH +2 ALLEN +3 WARD +4 JONES +5 SCOTT +6 ADAMS +7 MILLER + name +1 SMITH +2 ALLEN +3 WARD +4 JONES +5 SCOTT +6 ADAMS +7 FORD +8 MILLER + -- GitLab