Skip to content

Commit 9f1ae29

Browse files
CAG2MarkAnsonYeungLPTK
authored
Stack safety overhaul and effect handlers debloating MK1 (hkust-taco#307)
Co-authored-by: Anson Yeung <nanolify@gmail.com> Co-authored-by: Lionel Parreaux <lionel.parreaux@gmail.com>
1 parent 6084c15 commit 9f1ae29

22 files changed

Lines changed: 856 additions & 720 deletions

shared/src/main/scala/mlscript/utils/Lazy.scala renamed to core/shared/main/scala/utils/Lazy.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class Eager[+A](val value: A) extends Box[A] {
1616

1717
class Lazy[A](thunk: => A) extends Box[A] {
1818
def isComputing = _isComputing
19+
def isEmpty: Bool = _value.isEmpty
1920
private var _isComputing = false
2021
private var _value: Opt[A] = N
2122
def get = if (_isComputing) N else S(get_!)

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 352 additions & 101 deletions
Large diffs are not rendered by default.

hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
151151
* @param localPaths The path to access a particular local (possibly belonging to a previous function) in the current scope
152152
* @param iSymPaths The path to access a particular `innerSymbol` (possibly belonging to a previous class) in the current scope
153153
* @param replacedDefns Ignored (unlifted) definitions that have been rewritten and need to be replaced at the definition site.
154+
* @param firstClsFns Nested functions which are used as first-class functions.
154155
* @param companionMap Map from companion object symbols to the corresponding regular class symbol.
155156
*/
156157
case class LifterCtx private (
@@ -173,6 +174,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
173174
val localPaths: Map[Local, LocalPath] = Map.empty,
174175
val isymPaths: Map[InnerSymbol, LocalPath] = Map.empty,
175176
val replacedDefns: Map[BlockMemberSymbol, Defn] = Map.empty,
177+
val firstClsFns: Set[BlockMemberSymbol] = Set.empty,
176178
val companionMap: Map[InnerSymbol, InnerSymbol] = Map.empty
177179
):
178180
// gets the function to which a local belongs
@@ -197,6 +199,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
197199
def withNestedDefns(mp: Map[BlockMemberSymbol, List[Defn]]) = copy(nestedDefns = mp)
198200
def withAccesses(mp: Map[BlockMemberSymbol, AccessInfo]) = copy(accessInfo = mp)
199201
def withInScopes(mp: Map[BlockMemberSymbol, Set[BlockMemberSymbol]]) = copy(inScopeDefns = mp)
202+
def withFirstClsFns(fns: Set[BlockMemberSymbol]) = copy(firstClsFns = fns)
200203
def withCompanionMap(mp: Map[InnerSymbol, InnerSymbol]) = copy(companionMap = mp)
201204
def addFnLocals(f: FreeVars) = copy(prevFnLocals = prevFnLocals ++ f)
202205
def addClsDefn(c: ClsLikeDefn) = copy(prevClsDefns = c :: prevClsDefns)
@@ -267,7 +270,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
267270

268271
val sortedVars = cap.toArray.sortBy(_.uid).map: sym =>
269272
val id = fresh.make
270-
val nme = sym.nme + id + "$"
273+
val nme = sym.nme + "$capture$" + id
271274

272275
val ident = new Tree.Ident(nme)
273276
val varSym = VarSymbol(ident)
@@ -367,10 +370,18 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
367370
val extraDefns: List[Defn],
368371
)
369372

373+
private case class LifterMetadata(
374+
unliftable: Set[BlockMemberSymbol],
375+
modules: List[ClsLikeDefn],
376+
objects: List[ClsLikeDefn],
377+
firstClsFns: Set[BlockMemberSymbol]
378+
)
379+
370380
// d is a top-level definition
371381
// returns (ignored classes, modules, objects)
372-
def createMetadata(d: Defn, ctx: LifterCtx): (Set[BlockMemberSymbol], List[ClsLikeDefn], List[ClsLikeDefn]) =
382+
private def createMetadata(d: Defn, ctx: LifterCtx): LifterMetadata =
373383
var ignored: Set[BlockMemberSymbol] = Set.empty
384+
var firstClsFns: Set[BlockMemberSymbol] = Set.empty
374385
var unliftable: Set[BlockMemberSymbol] = Set.empty
375386
var clsSymToBms: Map[Local, BlockMemberSymbol] = Map.empty
376387
var modules: List[ClsLikeDefn] = Nil
@@ -483,6 +494,10 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
483494
applyBlock(ctor)
484495
mod.foreach(applyClsLikeBody)
485496

497+
def isFun(d: Defn) = d match
498+
case FunDefn(owner, sym, params, body) => true
499+
case _ => false
500+
486501
override def applyValue(v: Value): Unit = v match
487502
case RefOfBms(l) if clsSyms.contains(l) && !modOrObj(ctx.defns(l)) =>
488503
raise(WarningReport(
@@ -491,6 +506,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
491506
))
492507
ignored += l
493508
unliftable += l
509+
case RefOfBms(l) if ctx.defns.contains(l) && isFun(ctx.defns(l)) =>
510+
// naked reference to a function definition
511+
firstClsFns += l
494512
case _ => super.applyValue(v)
495513

496514
// analyze the extends graph
@@ -512,8 +530,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
512530
dfs(b)
513531
for s <- ignored do
514532
dfs(s)
515-
516-
(ignored ++ newUnliftable, modules.toList, objects.toList)
533+
534+
LifterMetadata(ignored ++ newUnliftable, modules.toList, objects.toList, firstClsFns)
517535

518536
extension (b: Block)
519537
private def floatOut(ctx: LifterCtx) =
@@ -902,7 +920,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
902920

903921
d match
904922
case f: FunDefn =>
905-
val createSym = (nme: String) =>
923+
def createSym(nme: String) =
906924
val vsym = VarSymbol(Tree.Ident(nme))
907925
(vsym, LocalPath.Sym(vsym))
908926
val (extraParams, newCtx, _) = createSymbolsUpdateCtx(createSym)
@@ -932,18 +950,17 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
932950
val mainDefn = FunDefn(f.owner, f.sym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)
933951
val auxDefn = FunDefn(N, singleCallBms, flatPlist, lifted.body)
934952

935-
936-
Lifted(mainDefn, auxDefn :: extras)
953+
if ctx.firstClsFns.contains(f.sym) then
954+
Lifted(mainDefn, auxDefn :: extras)
955+
else
956+
Lifted(auxDefn, extras) // we can include just the flattened defn
937957
case c: ClsLikeDefn =>
938-
val createSym: String => (VarSymbol, LocalPath.PubField) =
939-
// due to the possibility of capturing a TempSymbol in HandlerLowering, it is necessary to generate a discriminator
940-
val fresh = FreshInt()
941-
(nme: String) =>
942-
val id = fresh.make
943-
(
944-
VarSymbol(Tree.Ident(nme + "$" + id)),
945-
LocalPath.PubField(c.isym, BlockMemberSymbol(nme + "$" + id, Nil, true))
946-
)
958+
val fresh = FreshInt()
959+
def createSym(nme: String): (VarSymbol, LocalPath.PubField) =
960+
(
961+
VarSymbol(Tree.Ident(nme)),
962+
LocalPath.PubField(c.isym, BlockMemberSymbol(nme, Nil, true))
963+
)
947964
val (extraParams, newCtx, flds) = createSymbolsUpdateCtx(createSym)
948965

949966
// add aux params, private fields, update preCtor
@@ -1263,7 +1280,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
12631280
val walker1 = new BlockTransformerShallow(SymbolSubst()):
12641281
override def applyBlock(b: Block): Block = b match
12651282
case Define(d, rest) =>
1266-
val (unliftable, modules, objects) = createMetadata(d, ctx)
1283+
val LifterMetadata(unliftable, modules, objects, firstClsFns) = createMetadata(d, ctx)
12671284

12681285
val modObjLocals = (modules ++ objects).map: c =>
12691286
analyzer.nestedIn.get(c.sym) match
@@ -1280,6 +1297,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
12801297
val ctxx = ctx
12811298
.addIgnored(unliftable)
12821299
.withModObjLocals(modObjLocals)
1300+
.withFirstClsFns(firstClsFns)
12831301

12841302
val Lifted(lifted, extra) = d match
12851303
case f: FunDefn =>

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,13 +1029,15 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
10291029
val desug = LambdaRewriter.desugar(blk)
10301030

10311031
val handlerPaths = new HandlerPaths
1032+
1033+
val (withHandlers, doUnwindPaths) = config.effectHandlers.fold((desug, Map.empty)): opt =>
1034+
HandlerLowering(handlerPaths, opt).translateTopLevel(desug)
1035+
10321036
val stackSafe = config.stackSafety match
1033-
case N => desug
1034-
case S(sts) => StackSafeTransform(sts.stackLimit, handlerPaths).transformTopLevel(desug)
1035-
val withHandlers = config.effectHandlers.fold(stackSafe): opt =>
1036-
HandlerLowering(handlerPaths, opt).translateTopLevel(stackSafe)
1037+
case N => withHandlers
1038+
case S(sts) => StackSafeTransform(sts.stackLimit, handlerPaths, doUnwindPaths).transformTopLevel(withHandlers)
10371039

1038-
val flattened = withHandlers.flattened
1040+
val flattened = stackSafe.flattened
10391041

10401042
val lifted =
10411043
if lift then Lifter(S(handlerPaths)).transform(flattened)

hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ import hkmc2.codegen.*
77
import hkmc2.semantics.Elaborator.State
88
import hkmc2.semantics.*
99
import hkmc2.syntax.Tree
10+
import hkmc2.codegen.HandlerLowering.FnOrCls
1011

11-
class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
12+
class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[FnOrCls, Path])(using State):
1213
private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth")
14+
15+
val doUnwindFns = doUnwindMap.values.collect:
16+
case s: Select if s.symbol.isDefined => s.symbol.get
17+
case Value.Ref(sym) => sym
18+
.toSet
1319

1420
private val runtimePath: Path = State.runtimeSymbol.asPath
1521
private val checkDepthPath: Path = runtimePath.selN(Tree.Ident("checkDepth"))
16-
private val resetDepthPath: Path = runtimePath.selN(Tree.Ident("resetDepth"))
1722
private val runStackSafePath: Path = runtimePath.selN(Tree.Ident("runStackSafe"))
1823
private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT)
1924

@@ -24,18 +29,13 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
2429

2530
// Increases the stack depth, assigns the call to a value, then decreases the stack depth
2631
// then binds that value to a desired block
27-
def extractRes(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) =
28-
if isTailCall then
29-
blockBuilder
30-
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
31-
.ret(res)
32+
def extractRes(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol): Block =
33+
if isTailCall then Return(res, false)
3234
else
3335
val tmp = sym getOrElse TempSymbol(None, "tmp")
34-
val offsetGtDepth = TempSymbol(None, "offsetGtDepth")
3536
blockBuilder
36-
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
3737
.assign(tmp, res)
38-
.assign(tmp, Call(resetDepthPath, tmp.asPath.asArg :: curDepth.asPath.asArg :: Nil)(true, false))
38+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, curDepth.asPath)
3939
.rest(f(tmp.asPath))
4040

4141
def wrapStackSafe(body: Block, resSym: Local, rest: Block) =
@@ -51,6 +51,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
5151
def transform(b: Block, curDepth: => Symbol, isTopLevel: Bool = false): Block =
5252
def usesStack(r: Result) = r match
5353
case Call(Value.Ref(_: BuiltinSymbol), _) => false
54+
case c: Call if !c.mayRaiseEffects => false // a call can only trigger a stack delay if it can raise effects
5455
case _: Call | _: Instantiate => true
5556
case _ => false
5657

@@ -75,39 +76,20 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
7576
extract(r, false, _ => applyBlock(rest), S(lhs), curDepth)
7677
else
7778
super.applyBlock(b)
78-
case HandleBlock(l, res, par, args, cls, hdr, bod, rst) =>
79-
val l2 = applyLocal(l)
80-
val res2 = applyLocal(res)
81-
applyPath(par): par2 =>
82-
applyListOf(args, applyPath(_)(_)): args2 =>
83-
val cls2 = cls.subst
84-
val hdr2 = hdr.mapConserve(applyHandler)
85-
val bod2 = rewriteBlk(bod)
86-
val rst2 = applyBlock(rst)
87-
if isTopLevel then
88-
val newRes = TempSymbol(N, "res")
89-
val newHandler = HandleBlock(l2, newRes, par2, args2, cls2, hdr2, bod2, Ret(newRes.asPath))
90-
wrapStackSafe(newHandler, res2, rst2)
91-
else
92-
HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
79+
80+
case HandleBlock(l, res, par, args, cls, hdr, bod, rst) => lastWords("HandleBlock in stack safe transformation")
9381

9482
case _ => super.applyBlock(b)
9583

96-
override def applyHandler(hdr: Handler): Handler =
97-
val sym2 = hdr.sym.subst
98-
val resumeSym2 = hdr.resumeSym.subst
99-
val params2 = hdr.params.mapConserve(applyParamList)
100-
val body2 = rewriteBlk(hdr.body)
101-
Handler(sym2, resumeSym2, params2, body2)
84+
override def applyHandler(hdr: Handler): Handler = lastWords("HandleBlock in stack safe transformation")
10285

10386
override def applyResult(r: Result)(k: Result => Block): Block =
10487
if usesStack(r) then
10588
extract(r, false, k, N, curDepth)
10689
else
10790
super.applyResult(r)(k)
10891

109-
override def applyLam(lam: Lambda): Lambda =
110-
Lambda(lam.params, rewriteBlk(lam.body))
92+
override def applyLam(lam: Lambda): Lambda = lastWords("Lambda in stack safe transformation")
11193

11294
transform.applyBlock(b)
11395

@@ -121,40 +103,75 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
121103
case _ => ()
122104
trivial
123105

124-
def rewriteCls(defn: ClsLikeDefn, isTopLevel: Bool): ClsLikeDefn =
125-
val ClsLikeDefn(owner, isym, sym, k, paramsOpt, auxParams,
126-
parentPath, methods, privateFields, publicFields, preCtor, ctor, mod) = defn
127-
ClsLikeDefn(
128-
owner, isym, sym, k, paramsOpt, auxParams, parentPath,
129-
methods.map(rewriteFn),
130-
privateFields,
131-
publicFields, rewriteBlk(preCtor),
132-
rewriteBlk(ctor),
133-
mod.map(rewriteObjBody(_, isTopLevel)),
134-
)
106+
def rewriteCls(defn: ClsLikeDefn, isTopLevel: Bool): ClsLikeDefn = defn.parentPath match
107+
case Some(value) if value eq paths.contClsPath => defn
108+
case _ =>
109+
val ClsLikeDefn(owner, isym, sym, k, paramsOpt, auxParams,
110+
parentPath, methods, privateFields, publicFields, preCtor, ctor, mod) = defn
111+
ClsLikeDefn(
112+
owner, isym, sym, k, paramsOpt, auxParams, parentPath,
113+
methods.map(rewriteFn),
114+
privateFields,
115+
publicFields,
116+
rewriteBlk(preCtor, L(BlockMemberSymbol("TODO", Nil)), 1), // TODO: preCtor is not translated in handler lowering
117+
if isTopLevel && (defn.k is syntax.Mod) then transformTopLevel(ctor) else rewriteBlk(ctor, R(isym), 1),
118+
mod.map(rewriteObjBody(_, isTopLevel)),
119+
)
135120

136121
def rewriteObjBody(defn: ClsLikeBody, isTopLevel: Bool): ClsLikeBody =
137122
ClsLikeBody(
138123
defn.isym,
139124
defn.methods.map(rewriteFn),
140125
defn.privateFields,
141126
defn.publicFields,
142-
if isTopLevel then transformTopLevel(defn.ctor) else rewriteBlk(defn.ctor),
127+
if isTopLevel then transformTopLevel(defn.ctor) else rewriteBlk(defn.ctor, R(defn.isym), 1),
143128
)
144129

145-
def rewriteBlk(blk: Block) =
146-
val curDepth =
130+
// fnOrCls points us to the doUnwind function
131+
def rewriteBlk(blk: Block, fnOrCls: FnOrCls, increment: Int) =
132+
var usedDepth = false
133+
lazy val curDepth =
134+
usedDepth = true
147135
TempSymbol(None, "curDepth")
136+
137+
val doUnwindPath = doUnwindMap.get(fnOrCls)
148138
val newBody = transform(blk, curDepth)
139+
149140
if isTrivial(blk) then
150141
newBody
151-
else
142+
else if doUnwindPath.isEmpty then
152143
val resSym = TempSymbol(None, "stackDelayRes")
153144
blockBuilder
154-
.assign(curDepth, stackDepthPath)
145+
.staticif(usedDepth, _.assign(curDepth, stackDepthPath))
146+
.rest(newBody)
147+
else
148+
val resSym = TempSymbol(None, "stackDelayRes")
149+
val rewritten = blockBuilder
150+
.staticif(usedDepth, _.assign(curDepth, stackDepthPath))
151+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(increment)))
155152
.assign(resSym, Call(checkDepthPath, Nil)(true, true))
153+
.ifthen(
154+
resSym.asPath,
155+
Case.Cls(paths.effectSigSym, paths.effectSigPath),
156+
Return(
157+
Call(doUnwindPath.get, resSym.asPath.asArg :: intLit(0).asArg :: Nil)(true, false),
158+
false
159+
)
160+
)
156161
.rest(newBody)
157-
158-
def rewriteFn(defn: FunDefn) = FunDefn(defn.owner, defn.sym, defn.params, rewriteBlk(defn.body))
162+
// Float out defns, including the doUnwind function, so that they appear at the top of the block
163+
// This is because the doUnwind function must appear before the checks inserted by the stack
164+
// safety pass.
165+
// However, due to how tightly coupled the stack safety and handler lowering are, it might be
166+
// better to simply merge the two passes in the future.
167+
val (blk, defns) = doUnwindPath.get match
168+
case Value.Ref(sym) => rewritten.floatOutDefns()
169+
case _ => (rewritten, Nil)
170+
defns.foldLeft(blk)((acc, defn) => Define(defn, acc))
171+
172+
173+
def rewriteFn(defn: FunDefn) =
174+
if doUnwindFns.contains(defn.sym) then defn
175+
else FunDefn(defn.owner, defn.sym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))
159176

160177
def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true)

hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
173173
blkAccessesShallow(f.body).withoutLocals(fVars)
174174
case c: ClsLikeDefn =>
175175
val methodSyms = c.methods.map(_.sym).toSet
176-
c.methods.foldLeft(blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor)):
176+
val ret = c.methods.foldLeft(blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor)):
177177
case (acc, fDefn) =>
178178
// class methods do not need to be lifted, so we don't count calls to their methods.
179179
// a previous reference to this class's block member symbol is enough to assume any
@@ -182,6 +182,11 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
182182
// however, we must keep references to the class itself!
183183
val defnAccess = findAccessesShallow(fDefn)
184184
acc ++ defnAccess.withoutBms(methodSyms)
185+
if c.parentPath.isDefined && isHandlerClsPath(c.parentPath.get) then
186+
// for continuation classes, treat them like they only read variables
187+
AccessInfo(ret.accessed ++ ret.mutated, Set.empty, ret.refdDefns)
188+
else
189+
ret
185190
case _: ValDefn => AccessInfo.empty
186191

187192
accessedCache.getOrElseUpdate(defn.sym, create)
@@ -333,16 +338,7 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
333338

334339
def handleCalledBms(called: BlockMemberSymbol): Unit = defnSyms.get(called) match
335340
case None => ()
336-
case Some(defn) =>
337-
// special case continuation classes
338-
defn match
339-
case c: ClsLikeDefn => c.parentPath match
340-
case S(path) if isHandlerClsPath(path) => return
341-
// treat the continuation class as if it does not exist
342-
case _ => ()
343-
case _ => ()
344-
345-
341+
case Some(defn) =>
346342
val AccessInfo(accessed, muted, refd) = accessMap(defn.sym)
347343
val muts = muted.intersect(thisVars)
348344
val reads = defn.freeVars.intersect(thisVars) -- muts

0 commit comments

Comments
 (0)