Skip to content
Snippets Groups Projects
Commit b2a938f9 authored by Julien Lopez's avatar Julien Lopez
Browse files

Fixes and refactoring in type system

parent 2e9416ab
No related branches found
No related tags found
No related merge requests found
package qir.driver.sql;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import qir.ast.QIRIf;
import qir.ast.data.QIRLcons;
import qir.ast.data.QIRLdestr;
import qir.ast.data.QIRLnil;
import qir.ast.data.QIRRcons;
import qir.ast.data.QIRRdestr;
import qir.ast.data.QIRRnil;
import qir.ast.data.QIRTable;
import qir.ast.operator.QIRFilter;
import qir.ast.operator.QIRGroupBy;
import qir.ast.operator.QIRJoin;
import qir.ast.operator.QIRLeftJoin;
......@@ -22,37 +17,33 @@ import qir.ast.operator.QIRScan;
import qir.ast.operator.QIRSortBy;
import qir.types.QIRBooleanType;
import qir.types.QIRConstantType;
import qir.types.QIRFunctionType;
import qir.types.QIRListType;
import qir.types.QIRRecordType;
import qir.types.QIRSomeType;
import qir.types.QIRType;
import qir.typing.QIRDefaultTypeSystemVisitor;
import qir.typing.QIRTypeErrorException;
public class QIRSQLTypeSystemVisitor extends QIRDefaultTypeSystemVisitor {
private static final QIRRecordType anyRelationalRecord = QIRRecordType.anyRestrictedTo(QIRConstantType.ANY);
private static final QIRListType anyRelationalTable = new QIRListType(anyRelationalRecord);
private QIRSQLTypeSystemVisitor(final QIRRecordType anyRelationalRecord) {
super(anyRelationalRecord, new QIRListType(anyRelationalRecord));
}
@Override
public final QIRType visit(final QIRProject qirProject) {
final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), anyRelationalTable);
final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, anyRelationalRecord));
public QIRSQLTypeSystemVisitor() {
this(QIRRecordType.anyRestrictedTo(QIRConstantType.ANY));
}
return new QIRListType(formatterType.getReturnType());
public QIRSQLTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) {
super(anyRecordType, anyListType);
}
@Override
public final QIRType visit(final QIRScan qirScan) {
return expectIfSubtype(qirScan.getTable().accept(this), anyRelationalTable);
public final QIRType visit(final QIRProject qirProject) {
return visit(qirProject, anyRecordType);
}
@Override
public final QIRType visit(final QIRFilter qirFilter) {
final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), anyRelationalTable);
final QIRType filterArgument = expectIfSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance())).getDomain()[0];
return new QIRListType(filterArgument);
public final QIRType visit(final QIRScan qirScan) {
return expectIfSubtype(qirScan.getTable().accept(this), anyListType);
}
@Override
......@@ -104,39 +95,13 @@ public class QIRSQLTypeSystemVisitor extends QIRDefaultTypeSystemVisitor {
throw new QIRTypeErrorException(this.getClass(), qirLnil);
}
@Override
public final QIRType visit(final QIRLcons qirLcons) {
final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), anyRelationalTable);
return new QIRListType(expectIfSubtype(qirLcons.getValue().accept(this), tailType.getElementType()));
}
@Override
public final QIRType visit(final QIRLdestr qirLdestr) {
throw new QIRTypeErrorException(this.getClass(), qirLdestr);
}
@Override
public final QIRType visit(final QIRRnil qirRnil) {
return QIRRecordType.ANY;
}
@Override
public final QIRType visit(final QIRRcons qirRcons) {
final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), anyRelationalRecord);
final Map<String, QIRType> fieldTypes = new HashMap<>();
tailType.getFieldTypes().entrySet().forEach(e -> fieldTypes.put(e.getKey(), e.getValue()));
fieldTypes.put(qirRcons.getId(), qirRcons.getValue().accept(this));
return new QIRRecordType(fieldTypes);
}
@Override
public final QIRType visit(final QIRRdestr qirRdestr) {
final String colName = qirRdestr.getColName();
final Map<String, QIRType> expectedRecordFields = new HashMap<>();
expectedRecordFields.put(colName, new QIRSomeType());
return expectIfSubtype(qirRdestr.getRecord().accept(this), new QIRRecordType(expectedRecordFields, QIRConstantType.ANY)).getFieldTypes().get(colName);
return visit(qirRdestr, Optional.of(QIRConstantType.ANY));
}
}
......@@ -13,13 +13,16 @@ public class QIRRecordType extends QIRType {
private Optional<QIRType> globalRestriction;
public QIRRecordType(final Map<String, QIRType> fieldTypes) {
this.fieldTypes = fieldTypes;
this.globalRestriction = Optional.empty();
this(fieldTypes, Optional.empty());
}
public QIRRecordType(final Map<String, QIRType> fieldTypes, final QIRType globalRestriction) {
this(fieldTypes, Optional.of(globalRestriction));
}
public QIRRecordType(final Map<String, QIRType> fieldTypes, final Optional<QIRType> globalRestriction) {
this.fieldTypes = fieldTypes;
this.globalRestriction = Optional.of(globalRestriction);
this.globalRestriction = globalRestriction;
}
public static final QIRRecordType ANY = new QIRRecordType(new HashMap<>());
......
......@@ -2,6 +2,7 @@ package qir.typing;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import qir.ast.QIRApply;
import qir.ast.QIRDBNode;
......@@ -58,22 +59,34 @@ import qir.types.QIRStringType;
import qir.types.QIRType;
public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor {
public QIRType visit(final QIRProject qirProject) {
final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), QIRListType.ANY);
final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRAnyType.getInstance()));
public QIRDefaultTypeSystemVisitor() {
super(QIRRecordType.ANY, QIRListType.ANY);
}
public QIRDefaultTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) {
super(anyRecordType, anyListType);
}
protected QIRType visit(final QIRProject qirProject, final QIRType expectedFormatterReturnType) {
final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), anyListType);
final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, expectedFormatterReturnType));
return new QIRListType(formatterType.getReturnType());
}
public QIRType visit(final QIRProject qirProject) {
return visit(qirProject, QIRAnyType.getInstance());
}
public QIRType visit(final QIRScan qirScan) {
throw new QIRTypeErrorException(this.getClass(), qirScan);
}
public QIRType visit(final QIRFilter qirFilter) {
final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), QIRListType.ANY);
final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), anyListType);
final QIRType filterArgument = expectIfSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance())).getDomain()[0];
checkSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance()));
return new QIRListType(childType.getElementType());
return new QIRListType(filterArgument);
}
public QIRType visit(final QIRGroupBy qirGroupBy) {
......@@ -116,13 +129,22 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor {
public QIRType visit(final QIRLambda qirLambda) {
final QIRType varType = new QIRSomeType();
env.put(qirLambda.getVar().id, varType);
return new QIRFunctionType(new QIRType[]{varType}, qirLambda.getBody().accept(this));
final String varId = qirLambda.getVar().id;
final QIRType funType;
final Optional<QIRType> save = env.containsKey(varId) ? Optional.of(env.get(varId)) : Optional.empty();
env.put(varId, varType);
funType = new QIRFunctionType(new QIRType[]{varType}, qirLambda.getBody().accept(this));
if (save.isPresent())
env.put(varId, save.get());
else
env.remove(varId);
return funType;
}
public QIRType visit(final QIRApply qirApply) {
final QIRType rightType = qirApply.getRight() != null ? qirApply.getRight().accept(this) : QIRAnyType.getInstance();
return expectIfSubtype(qirApply.getLeft().accept(this), new QIRFunctionType(new QIRType[]{rightType}, QIRAnyType.getInstance())).getReturnType();
return expectIfSubtype(qirApply.getLeft().accept(this),
new QIRFunctionType(new QIRType[]{qirApply.getRight() != null ? qirApply.getRight().accept(this) : QIRAnyType.getInstance()}, QIRAnyType.getInstance())).getReturnType();
}
public QIRType visit(final QIRIf qirIf) {
......@@ -191,17 +213,17 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor {
}
public QIRType visit(final QIRLnil qirLnil) {
return QIRListType.ANY;
return anyListType;
}
public QIRType visit(final QIRLcons qirLcons) {
final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), QIRListType.ANY);
final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), anyListType);
return new QIRListType(expectIfSubtype(qirLcons.getValue().accept(this), tailType.getElementType()));
}
public QIRType visit(final QIRLdestr qirLdestr) {
final QIRListType listType = expectIfSubtype(qirLdestr.getList().accept(this), QIRListType.ANY);
final QIRListType listType = expectIfSubtype(qirLdestr.getList().accept(this), anyListType);
final QIRType returnType = qirLdestr.getIfEmpty().accept(this);
checkSubtype(qirLdestr.getHandler().accept(this), new QIRFunctionType(new QIRType[]{listType.getElementType(), listType}, returnType));
......@@ -209,11 +231,11 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor {
}
public QIRType visit(final QIRRnil qirRnil) {
return QIRRecordType.ANY;
return anyRecordType;
}
public QIRType visit(final QIRRcons qirRcons) {
final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), QIRRecordType.ANY);
final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), anyRecordType);
final Map<String, QIRType> fieldTypes = new HashMap<>();
tailType.getFieldTypes().entrySet().forEach(e -> fieldTypes.put(e.getKey(), e.getValue()));
......@@ -221,12 +243,17 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor {
return new QIRRecordType(fieldTypes);
}
public QIRType visit(final QIRRdestr qirRdestr) {
protected QIRType visit(final QIRRdestr qirRdestr, final Optional<QIRType> globalRecordRestriction) {
final String colName = qirRdestr.getColName();
final Map<String, QIRType> expectedRecordFields = new HashMap<>();
final QIRRecordType expectedRecordType = new QIRRecordType(expectedRecordFields, globalRecordRestriction);
expectedRecordFields.put(colName, new QIRSomeType());
return expectIfSubtype(qirRdestr.getRecord().accept(this), new QIRRecordType(expectedRecordFields)).getFieldTypes().get(colName);
return expectIfSubtype(qirRdestr.getRecord().accept(this), expectedRecordType).getFieldTypes().get(colName);
}
public QIRType visit(final QIRRdestr qirRdestr) {
return visit(qirRdestr, Optional.empty());
}
public QIRType visit(final QIRString qirString) {
......
......@@ -4,11 +4,20 @@ import java.util.HashMap;
import java.util.Map;
import qir.driver.IQIRVisitor;
import qir.types.QIRListType;
import qir.types.QIRRecordType;
import qir.types.QIRSomeType;
import qir.types.QIRType;
public abstract class QIRTypeSystemVisitor implements IQIRVisitor<QIRType> {
protected final Map<String, QIRType> env = new HashMap<>();
protected final QIRRecordType anyRecordType;
protected final QIRListType anyListType;
public QIRTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) {
this.anyRecordType = anyRecordType;
this.anyListType = anyListType;
}
protected void checkSubtype(final QIRType actual, final QIRType expected) {
if (!actual.isSubtypeOf(expected))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment