Skip to content
Snippets Groups Projects
Commit 939460e7 authored by Julien Lopez's avatar Julien Lopez
Browse files

Add support for group by in queries and a test on it

parent b49bec0b
No related branches found
No related tags found
No related merge requests found
...@@ -50,6 +50,7 @@ import com.oracle.truffle.r.nodes.function.SaveArgumentsNode; ...@@ -50,6 +50,7 @@ import com.oracle.truffle.r.nodes.function.SaveArgumentsNode;
import com.oracle.truffle.r.nodes.function.WrapDefaultArgumentNode; import com.oracle.truffle.r.nodes.function.WrapDefaultArgumentNode;
import com.oracle.truffle.r.nodes.function.signature.MissingNode; import com.oracle.truffle.r.nodes.function.signature.MissingNode;
import com.oracle.truffle.r.nodes.query.RFromNode; import com.oracle.truffle.r.nodes.query.RFromNode;
import com.oracle.truffle.r.nodes.query.RGroupNode;
import com.oracle.truffle.r.nodes.query.RQueryVisitor; import com.oracle.truffle.r.nodes.query.RQueryVisitor;
import com.oracle.truffle.r.nodes.query.RSelectNode; import com.oracle.truffle.r.nodes.query.RSelectNode;
import com.oracle.truffle.r.nodes.query.RWhereNode; import com.oracle.truffle.r.nodes.query.RWhereNode;
...@@ -169,6 +170,11 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> { ...@@ -169,6 +170,11 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> {
return new RWhereNode(source, filter.asRNode(), child.asRNode()); return new RWhereNode(source, filter.asRNode(), child.asRNode());
} }
@Override
public final RSyntaxNode groupby(final SourceSection source, final RSyntaxNode filter, final RSyntaxNode child) {
return new RGroupNode(source, filter.asRNode(), child.asRNode());
}
private static ArgumentsSignature createSignature(List<Argument<RSyntaxNode>> args) { private static ArgumentsSignature createSignature(List<Argument<RSyntaxNode>> args) {
String[] argumentNames = args.stream().map(arg -> arg.name).toArray(String[]::new); String[] argumentNames = args.stream().map(arg -> arg.name).toArray(String[]::new);
ArgumentsSignature signature = ArgumentsSignature.get(argumentNames); ArgumentsSignature signature = ArgumentsSignature.get(argumentNames);
......
...@@ -238,11 +238,19 @@ public final class QIRInterface { ...@@ -238,11 +238,19 @@ public final class QIRInterface {
if (fun.isBuiltin()) if (fun.isBuiltin())
switch (fun.getName()) { switch (fun.getName()) {
case "new.env": case "new.env":
return new QIRLambda(null, "new.env", new QIRVariable(null, "_", null), QIRTnil.getInstance()); return new QIRLambda(src, "new.env", new QIRVariable(null, "_", null), QIRTnil.getInstance());
case "return": case "return":
case "(": case "(": {
final QIRVariable x = new QIRVariable(null, "x", null); final QIRVariable x = new QIRVariable(null, "x", null);
return new QIRLambda(null, "identity", x, x); return new QIRLambda(src, "identity", x, x);
}
case "c": {
// TODO: This works only for lists with one element
final QIRVariable x = new QIRVariable(null, "x", null);
return new QIRLambda(src, "lcons", x, new QIRLcons(null, x, QIRLnil.getInstance()));
}
case "sum":
return new QIRBuiltin(src, "sum");
default: default:
throw new RuntimeException("Unsupported value: " + value + " : " + value.getClass()); throw new RuntimeException("Unsupported value: " + value + " : " + value.getClass());
} }
......
...@@ -416,6 +416,7 @@ query returns [T v] ...@@ -416,6 +416,7 @@ query returns [T v]
: op=SELECT n_ LPAR n_ formatter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.select(src($op, $end), $formatter.v, $child.v); } : op=SELECT n_ LPAR n_ formatter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.select(src($op, $end), $formatter.v, $child.v); }
| op=FROM n_ LPAR n_ child=expr end=RPAR { $v = builder.from(src($op, $end), $child.v); } | op=FROM n_ LPAR n_ child=expr end=RPAR { $v = builder.from(src($op, $end), $child.v); }
| op=WHERE n_ LPAR n_ filter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.where(src($op, $end), $filter.v, $child.v); } | op=WHERE n_ LPAR n_ filter=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.where(src($op, $end), $filter.v, $child.v); }
| op=GROUP n_ LPAR n_ group=expr n_ COMMA n_ child=expr end=RPAR { $v = builder.groupby(src($op, $end), $group.v, $child.v); }
; ;
number returns [T v] number returns [T v]
...@@ -572,6 +573,7 @@ BREAK : 'break' ; ...@@ -572,6 +573,7 @@ BREAK : 'break' ;
SELECT : 'select' ; SELECT : 'select' ;
FROM : 'from' ; FROM : 'from' ;
WHERE : 'where' ; WHERE : 'where' ;
GROUP : 'groupby' ;
WS : ('\u0009'|'\u0020'|'\u00A0') { $channel=HIDDEN; } ; WS : ('\u0009'|'\u0020'|'\u00A0') { $channel=HIDDEN; } ;
NEWLINE : LINE_BREAK { if(incompleteNesting > 0) $channel=HIDDEN; } ; NEWLINE : LINE_BREAK { if(incompleteNesting > 0) $channel=HIDDEN; } ;
......
...@@ -114,6 +114,11 @@ public interface RCodeBuilder<T> { ...@@ -114,6 +114,11 @@ public interface RCodeBuilder<T> {
*/ */
T where(final SourceSection source, final T filter, final T child); T where(final SourceSection source, final T filter, final T child);
/**
* Creates a where query.
*/
T groupby(final SourceSection source, final T group, final T child);
/** /**
* Creates a constant, the value is expected to be one of FastR's scalar types (byte, int, * Creates a constant, the value is expected to be one of FastR's scalar types (byte, int,
* double, RComplex, String, RNull). * double, RComplex, String, RNull).
......
emp = new.table("emp", "PostgreSQL", "postgre.config", "public")
q = groupby(function (x) { c(x$deptno) },
select(function (x) { res = new.env(); res$deptno = x$deptno; res$sum = sum(x$sal); res },
from(emp)))
results = query.force(q)
for (r in results) { print (sprintf("%d, %d", r$deptno, r$sum)) }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment