@@ -7,13 +7,18 @@ import hkmc2.codegen.*
77import hkmc2 .semantics .Elaborator .State
88import hkmc2 .semantics .*
99import hkmc2 .syntax .Tree
10+ import hkmc2 .codegen .HandlerLowering .FnOrCls
1011
11- class StackSafeTransform (depthLimit : Int , paths : HandlerPaths )(using State ):
12+ class StackSafeTransform (depthLimit : Int , paths : HandlerPaths , doUnwindMap : Map [ FnOrCls , Path ] )(using State ):
1213 private val STACK_DEPTH_IDENT : Tree .Ident = Tree .Ident (" stackDepth" )
14+
15+ val doUnwindFns = doUnwindMap.values.collect:
16+ case s : Select if s.symbol.isDefined => s.symbol.get
17+ case Value .Ref (sym) => sym
18+ .toSet
1319
1420 private val runtimePath : Path = State .runtimeSymbol.asPath
1521 private val checkDepthPath : Path = runtimePath.selN(Tree .Ident (" checkDepth" ))
16- private val resetDepthPath : Path = runtimePath.selN(Tree .Ident (" resetDepth" ))
1722 private val runStackSafePath : Path = runtimePath.selN(Tree .Ident (" runStackSafe" ))
1823 private val stackDepthPath : Path = runtimePath.selN(STACK_DEPTH_IDENT )
1924
@@ -24,18 +29,13 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
2429
2530 // Increases the stack depth, assigns the call to a value, then decreases the stack depth
2631 // then binds that value to a desired block
27- def extractRes (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ) =
28- if isTailCall then
29- blockBuilder
30- .assignFieldN(runtimePath, STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
31- .ret(res)
32+ def extractRes (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ): Block =
33+ if isTailCall then Return (res, false )
3234 else
3335 val tmp = sym getOrElse TempSymbol (None , " tmp" )
34- val offsetGtDepth = TempSymbol (None , " offsetGtDepth" )
3536 blockBuilder
36- .assignFieldN(runtimePath, STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
3737 .assign(tmp, res)
38- .assign(tmp, Call (resetDepthPath, tmp.asPath.asArg :: curDepth.asPath.asArg :: Nil )( true , false ) )
38+ .assignFieldN(runtimePath, STACK_DEPTH_IDENT , curDepth.asPath)
3939 .rest(f(tmp.asPath))
4040
4141 def wrapStackSafe (body : Block , resSym : Local , rest : Block ) =
@@ -51,6 +51,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
5151 def transform (b : Block , curDepth : => Symbol , isTopLevel : Bool = false ): Block =
5252 def usesStack (r : Result ) = r match
5353 case Call (Value .Ref (_ : BuiltinSymbol ), _) => false
54+ case c : Call if ! c.mayRaiseEffects => false // a call can only trigger a stack delay if it can raise effects
5455 case _ : Call | _ : Instantiate => true
5556 case _ => false
5657
@@ -75,39 +76,20 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
7576 extract(r, false , _ => applyBlock(rest), S (lhs), curDepth)
7677 else
7778 super .applyBlock(b)
78- case HandleBlock (l, res, par, args, cls, hdr, bod, rst) =>
79- val l2 = applyLocal(l)
80- val res2 = applyLocal(res)
81- applyPath(par): par2 =>
82- applyListOf(args, applyPath(_)(_)): args2 =>
83- val cls2 = cls.subst
84- val hdr2 = hdr.mapConserve(applyHandler)
85- val bod2 = rewriteBlk(bod)
86- val rst2 = applyBlock(rst)
87- if isTopLevel then
88- val newRes = TempSymbol (N , " res" )
89- val newHandler = HandleBlock (l2, newRes, par2, args2, cls2, hdr2, bod2, Ret (newRes.asPath))
90- wrapStackSafe(newHandler, res2, rst2)
91- else
92- HandleBlock (l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
79+
80+ case HandleBlock (l, res, par, args, cls, hdr, bod, rst) => lastWords(" HandleBlock in stack safe transformation" )
9381
9482 case _ => super .applyBlock(b)
9583
96- override def applyHandler (hdr : Handler ): Handler =
97- val sym2 = hdr.sym.subst
98- val resumeSym2 = hdr.resumeSym.subst
99- val params2 = hdr.params.mapConserve(applyParamList)
100- val body2 = rewriteBlk(hdr.body)
101- Handler (sym2, resumeSym2, params2, body2)
84+ override def applyHandler (hdr : Handler ): Handler = lastWords(" HandleBlock in stack safe transformation" )
10285
10386 override def applyResult (r : Result )(k : Result => Block ): Block =
10487 if usesStack(r) then
10588 extract(r, false , k, N , curDepth)
10689 else
10790 super .applyResult(r)(k)
10891
109- override def applyLam (lam : Lambda ): Lambda =
110- Lambda (lam.params, rewriteBlk(lam.body))
92+ override def applyLam (lam : Lambda ): Lambda = lastWords(" Lambda in stack safe transformation" )
11193
11294 transform.applyBlock(b)
11395
@@ -121,40 +103,75 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
121103 case _ => ()
122104 trivial
123105
124- def rewriteCls (defn : ClsLikeDefn , isTopLevel : Bool ): ClsLikeDefn =
125- val ClsLikeDefn (owner, isym, sym, k, paramsOpt, auxParams,
126- parentPath, methods, privateFields, publicFields, preCtor, ctor, mod) = defn
127- ClsLikeDefn (
128- owner, isym, sym, k, paramsOpt, auxParams, parentPath,
129- methods.map(rewriteFn),
130- privateFields,
131- publicFields, rewriteBlk(preCtor),
132- rewriteBlk(ctor),
133- mod.map(rewriteObjBody(_, isTopLevel)),
134- )
106+ def rewriteCls (defn : ClsLikeDefn , isTopLevel : Bool ): ClsLikeDefn = defn.parentPath match
107+ case Some (value) if value eq paths.contClsPath => defn
108+ case _ =>
109+ val ClsLikeDefn (owner, isym, sym, k, paramsOpt, auxParams,
110+ parentPath, methods, privateFields, publicFields, preCtor, ctor, mod) = defn
111+ ClsLikeDefn (
112+ owner, isym, sym, k, paramsOpt, auxParams, parentPath,
113+ methods.map(rewriteFn),
114+ privateFields,
115+ publicFields,
116+ rewriteBlk(preCtor, L (BlockMemberSymbol (" TODO" , Nil )), 1 ), // TODO: preCtor is not translated in handler lowering
117+ if isTopLevel && (defn.k is syntax.Mod ) then transformTopLevel(ctor) else rewriteBlk(ctor, R (isym), 1 ),
118+ mod.map(rewriteObjBody(_, isTopLevel)),
119+ )
135120
136121 def rewriteObjBody (defn : ClsLikeBody , isTopLevel : Bool ): ClsLikeBody =
137122 ClsLikeBody (
138123 defn.isym,
139124 defn.methods.map(rewriteFn),
140125 defn.privateFields,
141126 defn.publicFields,
142- if isTopLevel then transformTopLevel(defn.ctor) else rewriteBlk(defn.ctor),
127+ if isTopLevel then transformTopLevel(defn.ctor) else rewriteBlk(defn.ctor, R (defn.isym), 1 ),
143128 )
144129
145- def rewriteBlk (blk : Block ) =
146- val curDepth =
130+ // fnOrCls points us to the doUnwind function
131+ def rewriteBlk (blk : Block , fnOrCls : FnOrCls , increment : Int ) =
132+ var usedDepth = false
133+ lazy val curDepth =
134+ usedDepth = true
147135 TempSymbol (None , " curDepth" )
136+
137+ val doUnwindPath = doUnwindMap.get(fnOrCls)
148138 val newBody = transform(blk, curDepth)
139+
149140 if isTrivial(blk) then
150141 newBody
151- else
142+ else if doUnwindPath.isEmpty then
152143 val resSym = TempSymbol (None , " stackDelayRes" )
153144 blockBuilder
154- .assign(curDepth, stackDepthPath)
145+ .staticif(usedDepth, _.assign(curDepth, stackDepthPath))
146+ .rest(newBody)
147+ else
148+ val resSym = TempSymbol (None , " stackDelayRes" )
149+ val rewritten = blockBuilder
150+ .staticif(usedDepth, _.assign(curDepth, stackDepthPath))
151+ .assignFieldN(runtimePath, STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(increment)))
155152 .assign(resSym, Call (checkDepthPath, Nil )(true , true ))
153+ .ifthen(
154+ resSym.asPath,
155+ Case .Cls (paths.effectSigSym, paths.effectSigPath),
156+ Return (
157+ Call (doUnwindPath.get, resSym.asPath.asArg :: intLit(0 ).asArg :: Nil )(true , false ),
158+ false
159+ )
160+ )
156161 .rest(newBody)
157-
158- def rewriteFn (defn : FunDefn ) = FunDefn (defn.owner, defn.sym, defn.params, rewriteBlk(defn.body))
162+ // Float out defns, including the doUnwind function, so that they appear at the top of the block
163+ // This is because the doUnwind function must appear before the checks inserted by the stack
164+ // safety pass.
165+ // However, due to how tightly coupled the stack safety and handler lowering are, it might be
166+ // better to simply merge the two passes in the future.
167+ val (blk, defns) = doUnwindPath.get match
168+ case Value .Ref (sym) => rewritten.floatOutDefns()
169+ case _ => (rewritten, Nil )
170+ defns.foldLeft(blk)((acc, defn) => Define (defn, acc))
171+
172+
173+ def rewriteFn (defn : FunDefn ) =
174+ if doUnwindFns.contains(defn.sym) then defn
175+ else FunDefn (defn.owner, defn.sym, defn.params, rewriteBlk(defn.body, L (defn.sym), 1 ))
159176
160177 def transformTopLevel (b : Block ) = transform(b, TempSymbol (N ), true )
0 commit comments