diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java index 7b8013d3efc086cdc6b4479260e54af138e2684c..10e7feb656b6ef820fad51bf4ff40dd09bd80690 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java @@ -50,6 +50,7 @@ import com.oracle.truffle.r.nodes.function.SaveArgumentsNode; import com.oracle.truffle.r.nodes.function.WrapDefaultArgumentNode; import com.oracle.truffle.r.nodes.function.signature.MissingNode; import com.oracle.truffle.r.nodes.query.RFromNode; +import com.oracle.truffle.r.nodes.query.RGroupNode; import com.oracle.truffle.r.nodes.query.RQueryVisitor; import com.oracle.truffle.r.nodes.query.RSelectNode; import com.oracle.truffle.r.nodes.query.RWhereNode; @@ -169,6 +170,11 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> { return new RWhereNode(source, filter.asRNode(), child.asRNode()); } + @Override + public final RSyntaxNode groupby(final SourceSection source, final RSyntaxNode filter, final RSyntaxNode child) { + return new RGroupNode(source, filter.asRNode(), child.asRNode()); + } + private static ArgumentsSignature createSignature(List<Argument<RSyntaxNode>> args) { String[] argumentNames = args.stream().map(arg -> arg.name).toArray(String[]::new); ArgumentsSignature signature = ArgumentsSignature.get(argumentNames); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java index 21f128cd563f60f1dc14e0938f031c73d150e129..24fd4f9806dde090cc882c3acef7b6ffc4dc3ff6 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/qirinterface/QIRInterface.java @@ -238,11 +238,19 @@ public final class QIRInterface { if (fun.isBuiltin()) switch (fun.getName()) { case "new.env": - return new QIRLambda(null, "new.env", new QIRVariable(null, "_", null), QIRTnil.getInstance()); + return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.getInstance()); case "return": - case "(": + case "(": { final QIRVariable x = new QIRVariable(null, "x", null); - return new QIRLambda(null, "identity", x, x); + return new QIRLambda(src, "identity", x, x); + } + case "c": { + // TODO: This works only for lists with one element + final QIRVariable x = new QIRVariable(null, "x", null); + return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.getInstance())); + } + case "sum": + return new QIRBuiltin(src, "sum"); default: throw new RuntimeException("Unsupported value: " + value + " : " + value.getClass()); } diff --git a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g index f89b658919646eaab8318991a9f8a16748524953..a03d0487305fba4ef9eae568bd07469a40293e15 100644 --- a/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g +++ b/com.oracle.truffle.r.parser/src/com/oracle/truffle/r/parser/R.g @@ -416,6 +416,7 @@ query returns [T v] : op=SELECT n_ LPAR n_ formatter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.select(src($op, $end), $formatter.v, $child.v); } | op=FROM n_ LPAR n_ child=expr end=RPAR { $v = builder.from(src($op, $end), $child.v); } | op=WHERE n_ LPAR n_ filter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.where(src($op, $end), $filter.v, $child.v); } + | op=GROUP n_ LPAR n_ group=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.groupby(src($op, $end), $group.v, $child.v); } ; number returns [T v] @@ -572,6 +573,7 @@ BREAK : 'break' ; SELECT : 'select' ; FROM : 'from' ; WHERE : 'where' ; +GROUP : 'groupby' ; WS : ('\u0009'|'\u0020'|'\u00A0') { $channel=HIDDEN; } ; NEWLINE : LINE_BREAK { if(incompleteNesting > 0) $channel=HIDDEN; } ; diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RCodeBuilder.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RCodeBuilder.java index 45d89f4abd67d72f830de7d19e640eb30fb53735..9d2bdff63a82add58ba4ec2ffd0c14d4075489f9 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RCodeBuilder.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RCodeBuilder.java @@ -114,6 +114,11 @@ public interface RCodeBuilder<T> { */ T where(final SourceSection source, final T filter, final T child); + /** + * Creates a where query. + */ + T groupby(final SourceSection source, final T group, final T child); + /** * Creates a constant, the value is expected to be one of FastR's scalar types (byte, int, * double, RComplex, String, RNull). diff --git a/com.oracle.truffle.r.test/tests/pgsql/QueryGroup.R b/com.oracle.truffle.r.test/tests/pgsql/QueryGroup.R new file mode 100644 index 0000000000000000000000000000000000000000..6be158169238e72c05c2957c4d54800a1966cdc5 --- /dev/null +++ b/com.oracle.truffle.r.test/tests/pgsql/QueryGroup.R @@ -0,0 +1,6 @@ +emp = new.table("emp", "PostgreSQL", "postgre.config", "public") +q = groupby(function (x) { c(x$deptno) }, + select(function (x) { res = new.env(); res$deptno = x$deptno; res$sum = sum(x$sal); res }, + from(emp))) +results = query.force(q) +for (r in results) { print (sprintf("%d, %d", r$deptno, r$sum)) }