@@ -1165,13 +1165,7 @@ extension StmtTypeChecker {
11651165 case let . columns( columnsDefs, constraints, options) :
11661166 var columns : Columns = [ : ]
11671167 for (name, def) in columnsDefs {
1168- let type = typeFor (
1169- column: def,
1170- tableColumns: columns,
1171- tableName: createTable. name. value
1172- )
1173- let isGenerated = def. constraints. contains { $0. isGenerated }
1174- let column = Column ( type: type, isGenerated: isGenerated)
1168+ let column = column ( for: def, columns: columns, tableName: createTable. name. value)
11751169 columns. append ( column, for: name. value)
11761170 }
11771171
@@ -1200,6 +1194,28 @@ extension StmtTypeChecker {
12001194 }
12011195 }
12021196
1197+ mutating func column(
1198+ for columnDef: ColumnDefSyntax ,
1199+ columns: Columns ,
1200+ tableName: Substring
1201+ ) -> Column {
1202+ let type = typeFor (
1203+ column: columnDef,
1204+ tableColumns: columns,
1205+ tableName: tableName
1206+ )
1207+
1208+ var isGenerated = false
1209+ var hasDefault = false
1210+
1211+ for constraint in columnDef. constraints {
1212+ isGenerated = constraint. isGenerated || isGenerated
1213+ hasDefault = constraint. isDefault || isGenerated
1214+ }
1215+
1216+ return Column ( type: type, hasDefault: hasDefault, isGenerated: isGenerated)
1217+ }
1218+
12031219 mutating func typeCheck( alterTable: AlterTableStmtSyntax ) {
12041220 var tableName = qualifedName ( for: alterTable. name, in: alterTable. schemaName)
12051221
@@ -1229,15 +1245,9 @@ extension StmtTypeChecker {
12291245 table. name = tableName
12301246 case let . renameColumn( oldName, newName) :
12311247 table. columns. rename ( oldName. value, to: newName. value)
1232- case let . addColumn( column) :
1233- let newType = typeFor (
1234- column: column,
1235- tableColumns: table. columns,
1236- tableName: table. name. name
1237- )
1238- let isGenerated = column. constraints. contains { $0. isGenerated }
1239- let newColumn = Column ( type: newType, isGenerated: isGenerated)
1240- table. columns. append ( newColumn, for: column. name. value)
1248+ case let . addColumn( def) :
1249+ let column = column ( for: def, columns: table. columns, tableName: table. name. name)
1250+ table. columns. append ( column, for: def. name. value)
12411251 case let . dropColumn( column) :
12421252 table. columns = Columns ( table. columns. filter { $0. key != column. value } )
12431253 }
0 commit comments