diff --git a/src/qir/driver/sql/QIRSQLTypeSystemVisitor.java b/src/qir/driver/sql/QIRSQLTypeSystemVisitor.java index 154274fafbe7c20234ac4b41d4d2344c0fed55a6..8a096006302f424706093311bba1a101f63fe0f9 100644 --- a/src/qir/driver/sql/QIRSQLTypeSystemVisitor.java +++ b/src/qir/driver/sql/QIRSQLTypeSystemVisitor.java @@ -1,17 +1,12 @@ package qir.driver.sql; -import java.util.HashMap; -import java.util.Map; +import java.util.Optional; import qir.ast.QIRIf; -import qir.ast.data.QIRLcons; import qir.ast.data.QIRLdestr; import qir.ast.data.QIRLnil; -import qir.ast.data.QIRRcons; import qir.ast.data.QIRRdestr; -import qir.ast.data.QIRRnil; import qir.ast.data.QIRTable; -import qir.ast.operator.QIRFilter; import qir.ast.operator.QIRGroupBy; import qir.ast.operator.QIRJoin; import qir.ast.operator.QIRLeftJoin; @@ -22,37 +17,33 @@ import qir.ast.operator.QIRScan; import qir.ast.operator.QIRSortBy; import qir.types.QIRBooleanType; import qir.types.QIRConstantType; -import qir.types.QIRFunctionType; import qir.types.QIRListType; import qir.types.QIRRecordType; -import qir.types.QIRSomeType; import qir.types.QIRType; import qir.typing.QIRDefaultTypeSystemVisitor; import qir.typing.QIRTypeErrorException; public class QIRSQLTypeSystemVisitor extends QIRDefaultTypeSystemVisitor { - private static final QIRRecordType anyRelationalRecord = QIRRecordType.anyRestrictedTo(QIRConstantType.ANY); - private static final QIRListType anyRelationalTable = new QIRListType(anyRelationalRecord); + private QIRSQLTypeSystemVisitor(final QIRRecordType anyRelationalRecord) { + super(anyRelationalRecord, new QIRListType(anyRelationalRecord)); + } - @Override - public final QIRType visit(final QIRProject qirProject) { - final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), anyRelationalTable); - final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, anyRelationalRecord)); + public QIRSQLTypeSystemVisitor() { + this(QIRRecordType.anyRestrictedTo(QIRConstantType.ANY)); + } - return new QIRListType(formatterType.getReturnType()); + public QIRSQLTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) { + super(anyRecordType, anyListType); } @Override - public final QIRType visit(final QIRScan qirScan) { - return expectIfSubtype(qirScan.getTable().accept(this), anyRelationalTable); + public final QIRType visit(final QIRProject qirProject) { + return visit(qirProject, anyRecordType); } @Override - public final QIRType visit(final QIRFilter qirFilter) { - final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), anyRelationalTable); - final QIRType filterArgument = expectIfSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance())).getDomain()[0]; - - return new QIRListType(filterArgument); + public final QIRType visit(final QIRScan qirScan) { + return expectIfSubtype(qirScan.getTable().accept(this), anyListType); } @Override @@ -104,39 +95,13 @@ public class QIRSQLTypeSystemVisitor extends QIRDefaultTypeSystemVisitor { throw new QIRTypeErrorException(this.getClass(), qirLnil); } - @Override - public final QIRType visit(final QIRLcons qirLcons) { - final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), anyRelationalTable); - - return new QIRListType(expectIfSubtype(qirLcons.getValue().accept(this), tailType.getElementType())); - } - @Override public final QIRType visit(final QIRLdestr qirLdestr) { throw new QIRTypeErrorException(this.getClass(), qirLdestr); } - @Override - public final QIRType visit(final QIRRnil qirRnil) { - return QIRRecordType.ANY; - } - - @Override - public final QIRType visit(final QIRRcons qirRcons) { - final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), anyRelationalRecord); - - final Map<String, QIRType> fieldTypes = new HashMap<>(); - tailType.getFieldTypes().entrySet().forEach(e -> fieldTypes.put(e.getKey(), e.getValue())); - fieldTypes.put(qirRcons.getId(), qirRcons.getValue().accept(this)); - return new QIRRecordType(fieldTypes); - } - @Override public final QIRType visit(final QIRRdestr qirRdestr) { - final String colName = qirRdestr.getColName(); - final Map<String, QIRType> expectedRecordFields = new HashMap<>(); - - expectedRecordFields.put(colName, new QIRSomeType()); - return expectIfSubtype(qirRdestr.getRecord().accept(this), new QIRRecordType(expectedRecordFields, QIRConstantType.ANY)).getFieldTypes().get(colName); + return visit(qirRdestr, Optional.of(QIRConstantType.ANY)); } } diff --git a/src/qir/types/QIRRecordType.java b/src/qir/types/QIRRecordType.java index 009e4dbf0c8ba26919c98bb2d4d1472ee3c99fc8..f2ae48add0c12d3a6481745fc1b0b9d61e0783f3 100644 --- a/src/qir/types/QIRRecordType.java +++ b/src/qir/types/QIRRecordType.java @@ -13,13 +13,16 @@ public class QIRRecordType extends QIRType { private Optional<QIRType> globalRestriction; public QIRRecordType(final Map<String, QIRType> fieldTypes) { - this.fieldTypes = fieldTypes; - this.globalRestriction = Optional.empty(); + this(fieldTypes, Optional.empty()); } public QIRRecordType(final Map<String, QIRType> fieldTypes, final QIRType globalRestriction) { + this(fieldTypes, Optional.of(globalRestriction)); + } + + public QIRRecordType(final Map<String, QIRType> fieldTypes, final Optional<QIRType> globalRestriction) { this.fieldTypes = fieldTypes; - this.globalRestriction = Optional.of(globalRestriction); + this.globalRestriction = globalRestriction; } public static final QIRRecordType ANY = new QIRRecordType(new HashMap<>()); diff --git a/src/qir/typing/QIRDefaultTypeSystemVisitor.java b/src/qir/typing/QIRDefaultTypeSystemVisitor.java index 0f377177e7c35732f8190ee1f266bb017c4746ae..515804fc1b889f268dfc9df7724414d45161125e 100644 --- a/src/qir/typing/QIRDefaultTypeSystemVisitor.java +++ b/src/qir/typing/QIRDefaultTypeSystemVisitor.java @@ -2,6 +2,7 @@ package qir.typing; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import qir.ast.QIRApply; import qir.ast.QIRDBNode; @@ -58,22 +59,34 @@ import qir.types.QIRStringType; import qir.types.QIRType; public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor { - public QIRType visit(final QIRProject qirProject) { - final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), QIRListType.ANY); - final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRAnyType.getInstance())); + public QIRDefaultTypeSystemVisitor() { + super(QIRRecordType.ANY, QIRListType.ANY); + } + + public QIRDefaultTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) { + super(anyRecordType, anyListType); + } + + protected QIRType visit(final QIRProject qirProject, final QIRType expectedFormatterReturnType) { + final QIRListType childType = expectIfSubtype(qirProject.getChild().accept(this), anyListType); + final QIRFunctionType formatterType = expectIfSubtype(qirProject.getFormatter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, expectedFormatterReturnType)); return new QIRListType(formatterType.getReturnType()); } + public QIRType visit(final QIRProject qirProject) { + return visit(qirProject, QIRAnyType.getInstance()); + } + public QIRType visit(final QIRScan qirScan) { throw new QIRTypeErrorException(this.getClass(), qirScan); } public QIRType visit(final QIRFilter qirFilter) { - final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), QIRListType.ANY); + final QIRListType childType = expectIfSubtype(qirFilter.getChild().accept(this), anyListType); + final QIRType filterArgument = expectIfSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance())).getDomain()[0]; - checkSubtype(qirFilter.getFilter().accept(this), new QIRFunctionType(new QIRType[]{childType.getElementType()}, QIRBooleanType.getInstance())); - return new QIRListType(childType.getElementType()); + return new QIRListType(filterArgument); } public QIRType visit(final QIRGroupBy qirGroupBy) { @@ -116,13 +129,22 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor { public QIRType visit(final QIRLambda qirLambda) { final QIRType varType = new QIRSomeType(); - env.put(qirLambda.getVar().id, varType); - return new QIRFunctionType(new QIRType[]{varType}, qirLambda.getBody().accept(this)); + final String varId = qirLambda.getVar().id; + final QIRType funType; + final Optional<QIRType> save = env.containsKey(varId) ? Optional.of(env.get(varId)) : Optional.empty(); + + env.put(varId, varType); + funType = new QIRFunctionType(new QIRType[]{varType}, qirLambda.getBody().accept(this)); + if (save.isPresent()) + env.put(varId, save.get()); + else + env.remove(varId); + return funType; } public QIRType visit(final QIRApply qirApply) { - final QIRType rightType = qirApply.getRight() != null ? qirApply.getRight().accept(this) : QIRAnyType.getInstance(); - return expectIfSubtype(qirApply.getLeft().accept(this), new QIRFunctionType(new QIRType[]{rightType}, QIRAnyType.getInstance())).getReturnType(); + return expectIfSubtype(qirApply.getLeft().accept(this), + new QIRFunctionType(new QIRType[]{qirApply.getRight() != null ? qirApply.getRight().accept(this) : QIRAnyType.getInstance()}, QIRAnyType.getInstance())).getReturnType(); } public QIRType visit(final QIRIf qirIf) { @@ -191,17 +213,17 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor { } public QIRType visit(final QIRLnil qirLnil) { - return QIRListType.ANY; + return anyListType; } public QIRType visit(final QIRLcons qirLcons) { - final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), QIRListType.ANY); + final QIRListType tailType = expectIfSubtype(qirLcons.getTail().accept(this), anyListType); return new QIRListType(expectIfSubtype(qirLcons.getValue().accept(this), tailType.getElementType())); } public QIRType visit(final QIRLdestr qirLdestr) { - final QIRListType listType = expectIfSubtype(qirLdestr.getList().accept(this), QIRListType.ANY); + final QIRListType listType = expectIfSubtype(qirLdestr.getList().accept(this), anyListType); final QIRType returnType = qirLdestr.getIfEmpty().accept(this); checkSubtype(qirLdestr.getHandler().accept(this), new QIRFunctionType(new QIRType[]{listType.getElementType(), listType}, returnType)); @@ -209,11 +231,11 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor { } public QIRType visit(final QIRRnil qirRnil) { - return QIRRecordType.ANY; + return anyRecordType; } public QIRType visit(final QIRRcons qirRcons) { - final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), QIRRecordType.ANY); + final QIRRecordType tailType = expectIfSubtype(qirRcons.getTail().accept(this), anyRecordType); final Map<String, QIRType> fieldTypes = new HashMap<>(); tailType.getFieldTypes().entrySet().forEach(e -> fieldTypes.put(e.getKey(), e.getValue())); @@ -221,12 +243,17 @@ public class QIRDefaultTypeSystemVisitor extends QIRTypeSystemVisitor { return new QIRRecordType(fieldTypes); } - public QIRType visit(final QIRRdestr qirRdestr) { + protected QIRType visit(final QIRRdestr qirRdestr, final Optional<QIRType> globalRecordRestriction) { final String colName = qirRdestr.getColName(); final Map<String, QIRType> expectedRecordFields = new HashMap<>(); + final QIRRecordType expectedRecordType = new QIRRecordType(expectedRecordFields, globalRecordRestriction); expectedRecordFields.put(colName, new QIRSomeType()); - return expectIfSubtype(qirRdestr.getRecord().accept(this), new QIRRecordType(expectedRecordFields)).getFieldTypes().get(colName); + return expectIfSubtype(qirRdestr.getRecord().accept(this), expectedRecordType).getFieldTypes().get(colName); + } + + public QIRType visit(final QIRRdestr qirRdestr) { + return visit(qirRdestr, Optional.empty()); } public QIRType visit(final QIRString qirString) { diff --git a/src/qir/typing/QIRTypeSystemVisitor.java b/src/qir/typing/QIRTypeSystemVisitor.java index a9e92a167c31edcc52f3061257dd1bb6b0cedf02..9e16a3a93a5c6a255a621ed72adbb3bc2d3e7863 100644 --- a/src/qir/typing/QIRTypeSystemVisitor.java +++ b/src/qir/typing/QIRTypeSystemVisitor.java @@ -4,11 +4,20 @@ import java.util.HashMap; import java.util.Map; import qir.driver.IQIRVisitor; +import qir.types.QIRListType; +import qir.types.QIRRecordType; import qir.types.QIRSomeType; import qir.types.QIRType; public abstract class QIRTypeSystemVisitor implements IQIRVisitor<QIRType> { protected final Map<String, QIRType> env = new HashMap<>(); + protected final QIRRecordType anyRecordType; + protected final QIRListType anyListType; + + public QIRTypeSystemVisitor(final QIRRecordType anyRecordType, final QIRListType anyListType) { + this.anyRecordType = anyRecordType; + this.anyListType = anyListType; + } protected void checkSubtype(final QIRType actual, final QIRType expected) { if (!actual.isSubtypeOf(expected))