Skip to content

Commit 323afbd

Browse files
committed
Fix inconsistency in class val parameter selection symbol
A distinct TermSymbol was previously generated and used instead of the correct one. Also fix the IR printer and expose a lifter issue.
1 parent 165dfbb commit 323afbd

25 files changed

Lines changed: 265 additions & 78 deletions

hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import mlscript.utils.*, shorthands.*
77
import hkmc2.utils.*
88
import hkmc2.utils.SymbolSubst
99

10-
import syntax.{Literal, Tree, ParamBind}
10+
import syntax.{Literal, Tree}
1111
import semantics.*
1212
import semantics.Elaborator.{Ctx, ctx}
1313
import semantics.Elaborator.State

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import hkmc2.utils.*
1111
import hkmc2.utils.SymbolSubst
1212
import hkmc2.Message.MessageContext
1313

14-
import syntax.{Literal, Tree, ParamBind}
14+
import syntax.{Literal, Tree}
1515
import semantics.*
1616
import semantics.Elaborator.ctx
1717
import semantics.Elaborator.State

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
258258
subTerm(bod)(r =>
259259
Define(ValDefn(td.tsym, td.sym, r)(cfgOverride),
260260
blockImpl(stats, res)))
261-
case syntax.LetBind | syntax.ParamBind | syntax.HandlerBind => fail:
261+
case syntax.LetBind | syntax.HandlerBind => fail:
262262
ErrorReport(
263263
msg"Unexpected declaration kind '${td.k.str}' in lowering" -> td.toLoc :: Nil,
264264
source = Diagnostic.Source.Compilation)

hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,25 @@ class Printer(using Raise, ShowCfg, SymbolPrinter, Config):
106106
then doc""
107107
else doc" " :: braced(doc"${docPrivFlds}${docPubFlds}${docCtor}${docMethods}")
108108

109+
def printParamLists(paramss: Ls[ParamList])(using Scope): Document =
110+
doc"${paramss.map(_.params.map(x => scope.allocateName(x.sym)).mkDocument("(", ", ", ")")).mkDocument("")}"
111+
109112
def print(defn: Defn)(using Scope): Document = defn match
110-
case FunDefn(own, sym, dSym, params, body) =>
113+
case FunDefn(own, sym, dSym, paramss, body) =>
111114
scope.nest.givenIn:
112-
val docParams = doc"${
113-
params.map(_.params.map(x => scope.allocateName(x.sym)).mkDocument("(", ", ", ")")).mkDocument("")}"
115+
val docParams = printParamLists(paramss)
114116
val docBody = print(body)
115117
doc"fun ${print(dSym)}${docParams} ${bracedbk(docBody)}"
116118
case ValDefn(tsym, sym, rhs) =>
117119
doc"val ${print(tsym)} = ${print(rhs)}"
118120
case ClsLikeDefn(own, isym, sym, ctorSym, k, paramsOpt, auxParams, parentSym, methods,
119121
privateFields, publicFields, preCtor, ctor, mod, bufferable)
120122
=> scope.nest.givenIn:
121-
val clsParams = paramsOpt.fold(Nil)(_.paramSyms)
122-
val auxClsParams = auxParams.flatMap(_.paramSyms)
123-
val ctorParams = (clsParams ++ auxClsParams).map(p => scope.allocateName(p))
124-
val docCtorParams = if clsParams.isEmpty then doc"" else doc"(${ctorParams.mkDocument(", ")})"
123+
val ctorParams = printParamLists(paramsOpt.toList ::: auxParams)
125124
val docStaged = if isym.defn.forall(_.hasStagedModifier.isEmpty) then doc"" else doc"staged "
126125
val docBody = print(privateFields, publicFields, methods, S(preCtor), ctor, ctorSym)
127126
val clsType = k.str
128-
val docCls = doc"${docStaged}${clsType} ${print(isym)}${docCtorParams}${docBody}"
127+
val docCls = doc"${docStaged}${clsType} ${print(isym)}${ctorParams}${docBody}"
129128
val docModule = mod match
130129
case Some(mod) =>
131130
val docStaged = if mod.isym.defn.forall(_.hasStagedModifier.isEmpty) then doc"" else doc"staged "

hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Analyze.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import mlscript.utils.*, shorthands.*
77
import semantics.*
88
import syntax.Tree
99
import scala.collection.mutable.{Set as MutSet, Map as MutMap, LinkedHashMap, LinkedHashSet}
10-
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, ParamBind, Fun, Ins}
10+
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, Fun, Ins}
1111

1212
type StratVarId = Uid[StratVar]
1313

@@ -621,7 +621,9 @@ class DeforestConstraintsCollector(val preAnalyzer: DeforestPreAnalyzer):
621621
val argsStrat = args.map:
622622
case Arg(_, a) => processResult(a)
623623
ctor match
624-
case cls: ClassSymbol => new Ctor(c.uid, instId)(ctor, cls.tree.clsParams.zip(argsStrat))
624+
case cls: ClassSymbol => new Ctor(c.uid, instId)(ctor,
625+
cls.tree.clsParams.headOption.getOrElse(Nil)//FIXME? case when there are only aux parameter lists
626+
.zip(argsStrat))
625627
case _: ModuleOrObjectSymbol => new Ctor(c.uid, instId)(ctor, Nil)
626628
case tupSize: Int => new Ctor(c.uid, instId)(tupSize, (0 until tupSize).zip(argsStrat).toList)
627629
case Call(fun, args) => handleCallLike(fun, args)
@@ -643,7 +645,7 @@ class DeforestConstraintsCollector(val preAnalyzer: DeforestPreAnalyzer):
643645
case s: TermSymbol =>
644646
// only parambind and let/vals in modules are handled here
645647
s.k match
646-
case ParamBind =>
648+
case _ if s.decl.exists(_.isInstanceOf[Param]) =>
647649
val obj = processResult(sel.qual)
648650
val selRes = freshVar("sel_res", cc.forFunGroup)
649651
cc.constrain(

hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Deforest.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import syntax.Tree
88
import utils.*
99
import mlscript.utils.*, shorthands.*
1010
import scala.collection.mutable
11-
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, ParamBind, Fun, Ins}
11+
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, Fun, Ins}
1212

1313
case class ImportedInfo(seeThroughMods: Ls[ClsLikeBody])
1414

@@ -35,15 +35,17 @@ object DeforestableSelect:
3535
val tSym = sSym.asTrm.get
3636
tSym.k match
3737
case (Ins | HandlerBind | MutVal) => None
38-
case ParamBind => Some(tSym)
38+
case _ if tSym.decl.exists(_.isInstanceOf[Param]) => Some(tSym)
3939
case ImmutVal =>
4040
// this can be a selection from module or a class
4141
tSym.owner.flatMap:
4242
// if selecting from a class, it is only
4343
// handleable if the selected field is
4444
// one of the cls params, because in deforestation
4545
// the known field information of a ctor is only its args
46-
case cls: ClassSymbol => cls.tree.clsParams.find(_.id == tSym.id)
46+
case cls: ClassSymbol => cls.tree.clsParams.headOption
47+
.getOrElse(Nil)//FIXME? case when there are only aux parameter lists
48+
.find(_.id == tSym.id)
4749
case mod: ModuleOrObjectSymbol => Some(tSym)
4850
case _ => None
4951
case LetBind =>

hkmc2/shared/src/main/scala/hkmc2/codegen/deforest/Rewrite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import mlscript.utils.*, shorthands.*
77
import semantics.*
88
import syntax.Tree
99
import scala.collection.mutable.{Set as MutSet, Map as MutMap, LinkedHashMap, Buffer}
10-
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, ParamBind, Fun, Ins}
10+
import hkmc2.syntax.{ImmutVal, MutVal, LetBind, HandlerBind, Val, Fun, Ins}
1111

1212

1313

@@ -398,7 +398,7 @@ class DeforestRewriter(val solver: DeforestConstrainSolver)(using Raise):
398398
)
399399
case s@DeforestableSelect(sym: TermSymbol) =>
400400
if branchSelSyms.isDefinedAt(s.uid.toCtorDtorId) then
401-
assert(sym.k is ParamBind)
401+
assert(sym.k.isInstanceOf[Val])
402402
k(Value.Ref(branchSelSyms(s.uid.toCtorDtorId)))
403403
else if solver.finalDtorSrcs.contains(s.uid.toCtorDtorId) then
404404
applyPath(s.qual)(k)

hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
921921
case _ => lastWords(s"Cannot select field from non-struct type: ${structInfo.compType.toWat}")
922922
val fieldIdx = symToField.get(sym)
923923
.orElse:
924-
// Workaround: TermSymbols are not correctly resolved, so match the fields by name instead
925924
sym match
926-
case trmSym: TermSymbol if trmSym.owner.flatMap(_.asBlkMember).exists(_ == thisSym) =>
927-
symToField.find((fieldSym, _) => fieldSym.nme == sym.nme).map((_, v) => v)
928925
case memSym: MemberSymbol if fieldOwner(memSym).contains(thisSym) =>
929926
symToField.find((fieldSym, _) => fieldSym.nme == sym.nme).map((_, v) => v)
930927
case _ => N

hkmc2/shared/src/main/scala/hkmc2/package.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
package hkmc2
22

3+
import sourcecode.{Line, FileName}
4+
5+
import mlscript.utils.*, shorthands.*
6+
import hkmc2.utils.*
7+
import hkmc2.Message.MessageContext
8+
39

410
extension [A](a: A)
511
infix inline def givenIn[R](inline k: A ?=> R) = k(using a)
@@ -11,3 +17,11 @@ extension [A](a: A)
1117
val identifierPattern: scala.util.matching.Regex = "^[A-Za-z_$][A-Za-z0-9_$]*$".r
1218

1319

20+
def softAssert(cond: Boolean, msg: => Str = "")(using Line, FileName, Raise): Unit =
21+
if !cond then
22+
raise:
23+
InternalError(
24+
msg"Compiler reached an unexpected state at '${summon[FileName].value}:${summon[Line].value}'${if msg == "" then "" else s": $msg"}" -> N
25+
:: msg"The compilation result may be incorrect." -> N
26+
:: Nil)
27+

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,8 +1299,12 @@ extends Importer with ucs.SplitElaborator:
12991299
res
13001300

13011301
def withFields(using Ctx)(fn: (Ctx) ?=> (Term.Blk, Ctx)): (Term.Blk, Ctx) =
1302-
val fields: Ls[Statement] = pss.flatMap: ps =>
1303-
ps.params.flatMap: p =>
1302+
softAssert(pss.sizeCompare(td.clsParams) === 0,
1303+
s"mismatched parameter list numbers ${pss} vs ${td.clsParams}")
1304+
val fields: Ls[Statement] = pss.zip(td.clsParams).flatMap: (ps, cps) =>
1305+
softAssert(ps.params.sizeCompare(cps) === 0,
1306+
s"mismatched param list lengths ${ps.params} vs ${cps}")
1307+
ps.params.zip(cps).flatMap: (p, cp) =>
13041308
// For class-like types, "desugar" the parameters into additional class fields.
13051309

13061310
val owner = td.symbol match
@@ -1313,7 +1317,8 @@ extends Importer with ucs.SplitElaborator:
13131317
then
13141318
val k = if p.flags.mut then MutVal else ImmutVal
13151319
val fsym = BlockMemberSymbol(p.sym.nme, Nil)
1316-
val tsym = TermSymbol(k, owner, p.sym.id) // TODO?
1320+
val tsym = cp
1321+
cp.decl = S(p)
13171322
val fdef = TermDefinition(
13181323
k,
13191324
fsym,

0 commit comments

Comments
 (0)