Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
val l = loweringCtx.registerTempSymbol(N)
Assign(l, r, setupTerm("Else", Value.Ref(l) :: Nil)(k))
case Split.End => setupTerm("End", Nil)(k)
case Split.LetSplit(sym, tail) =>
quoteSplit(sym.body): r1 =>
val l1 = loweringCtx.registerTempSymbol(N)
blockBuilder.assign(l1, r1)
.chain(b => quoteSplit(tail)(r2 => Assign(loweringCtx.registerTempSymbol(N), r2, b)))
.rest(setupTerm("LetSplit", Value.Ref(l1) :: Nil)(k))
case Split.UseSplit(sym) =>
setupTerm("UseSplit", Nil)(k)

lazy val setupFilename: Path =
val state = summon[State]
Expand Down
4 changes: 4 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
val (ty, res) = typeCheck(alts)
(ty, res, Nil, true)
case Split.End => (Bot, Bot, Nil, false)
case Split.LetSplit(sym, tail) => typeADTMatch(tail, sign)
case Split.UseSplit(sym) => typeADTMatch(sym.body, sign)

private def typeSplit
(split: Split, sign: Opt[GeneralType])(using ctx: InvalCtx)(using CCtx, Scope)
Expand Down Expand Up @@ -442,6 +444,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case S(sign) => ascribe(alts, sign)
case _ => typeCheck(alts)
case Split.End => (Bot, Bot)
case Split.LetSplit(sym, tail) => typeSplit(tail, sign)
case Split.UseSplit(sym) => typeSplit(sym.body, sign)

private def typeAllSplits
(split: Split, sign: Opt[GeneralType])(using ctx: InvalCtx)(using CCtx, Scope)
Expand Down
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Resolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,8 @@ object ModuleChecker:
case Split.Let(_, _, tail) => go(tail)
case Split.Else(term) => term :: Nil
case Split.End => Nil
case Split.LetSplit(sym, tail) => go(sym.body) ::: go(tail)
case Split.UseSplit(_) => Nil
go(s)

/** Checks if a symbol is of a type parameter. */
Expand Down
42 changes: 41 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Split.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ enum Split extends AutoLocated with ProductWithTail:
case Let(sym: BlockLocalSymbol, term: Term, tail: Split)
case Else(default: Term)
case End
/** Declares a named split (join point). The symbol's `body` holds the shared
* split; `tail` is the continuation where the symbol is in scope. */
case LetSplit(sym: SplitSymbol, tail: Split)
/** References a previously declared split (join point). */
case UseSplit(sym: SplitSymbol)

inline def ~:(head: Branch): Split = Split.Cons(head, this)

Expand All @@ -31,6 +36,8 @@ enum Split extends AutoLocated with ProductWithTail:
case Let(sym, term, tail) => Let(sym, term.mkClone, tail.mkClone)
case Else(default) => Else(default.mkClone)
case End => End
case LetSplit(sym, tail) => LetSplit(sym, tail.mkClone)
case UseSplit(sym) => UseSplit(sym)

/** Used to indicate whether the `Split` was duplicated during desugaring or
* normalization. */
Expand All @@ -47,31 +54,52 @@ enum Split extends AutoLocated with ProductWithTail:
case Cons(head, tail) => Cons(head, tail.duplicate)
case Let(name, term, tail) => Let(name, term, tail.duplicate)
case Else(default) => Else(default)
case End => End).setDuplicated
case End => End
case LetSplit(sym, tail) => LetSplit(sym, tail.duplicate)
case UseSplit(sym) => UseSplit(sym)).setDuplicated

lazy val isFull: Bool = this match
case Split.Cons(_, tail) => tail.isFull
case Split.Let(_, _, tail) => tail.isFull
case Split.Else(_) => true
case Split.End => false
case Split.LetSplit(_, tail) => tail.isFull
case Split.UseSplit(sym) => sym.body.isFull

lazy val isEmpty: Bool = this match
case Split.Let(_, _, tail) => tail.isEmpty
case Split.Else(_) | Split.Cons(_, _) => false
case Split.End => true
case Split.LetSplit(_, tail) => tail.isEmpty
case Split.UseSplit(sym) => sym.body.isEmpty

/** Approximate tree size, used to decide whether sharing via `LetSplit` is
* worthwhile compared to inlining (see `patMatConsequentSharingThreshold`). */
lazy val size: Int = this match
case Split.Cons(Branch(scrutinee, _, continuation), tail) =>
scrutinee.size + continuation.size + tail.size + 1
case Split.Let(_, term, tail) => term.size + tail.size + 1
case Split.Else(term) => term.size
case Split.End => 0
case Split.LetSplit(_, tail) => tail.size
case Split.UseSplit(_) => 0

final override def children: Vector[Located] = this match
case Split.Cons(head, tail) => Vector.double(head, tail)
case Split.Let(name, term, tail) => Vector.triple(name, term, tail)
case Split.Else(default) => Vector.single(default)
case Split.End => Vector.empty
case Split.LetSplit(sym, tail) => Vector.double(sym, tail)
case Split.UseSplit(sym) => Vector.single(sym)

def subTerms: Vector[Term] = this match
case Split.Cons(Branch(scrutinee, pattern, continuation), tail) =>
scrutinee +: (pattern.subTerms ++ continuation.subTerms ++ tail.subTerms)
case Split.Let(_, term, tail) => term +: tail.subTerms
case Split.Else(term) => Vector.single(term)
case Split.End => Vector.empty
case Split.LetSplit(sym, tail) => sym.body.subTerms ++ tail.subTerms
case Split.UseSplit(_) => Vector.empty

/** Free variable names, accounting for let bindings. */
lazy val freeVars: Set[Str] = this match
Expand All @@ -82,12 +110,16 @@ enum Split extends AutoLocated with ProductWithTail:
term.freeVars ++ (tail.freeVars - sym.nme)
case Split.Else(term) => term.freeVars
case Split.End => Set.empty
case Split.LetSplit(sym, tail) => sym.body.freeVars ++ tail.freeVars
case Split.UseSplit(sym) => sym.body.freeVars

final def showDbg(using DebugPrinter): String = this match
case Split.Cons(head, tail) => s"${head.showDbg}; ${tail.showDbg}"
case Split.Let(name, term, tail) => s"let ${name} = ${term.showDbg}; ${tail.showDbg}"
case Split.Else(default) => s"else ${default.showDbg}"
case Split.End => ""
case Split.LetSplit(sym, tail) => s"let-split ${sym.nme} = { ${sym.body.showDbg} }; ${tail.showDbg}"
case Split.UseSplit(sym) => s"$$${sym.nme}"

final override def withLoc(loco: Option[Loc]): this.type =
super.withLoc:
Expand All @@ -96,6 +128,7 @@ enum Split extends AutoLocated with ProductWithTail:
// which causes the assertion of distinctness of origins to fail.
case Split.End => N
case _: Split.Else => N // FIXME: @Luyu pls clean up this mess
case _: Split.UseSplit => N
case _ => loco

var isFallback: Bool = false
Expand All @@ -112,6 +145,8 @@ extension (split: Split)
case Split.Let(name, term, tail) => Split.Let(name, term, tail ~~: fallback)
case Split.Else(_) => lastWords("impossible since split is not full")
case Split.End => fallback
case Split.LetSplit(sym, tail) => Split.LetSplit(sym, tail ~~: fallback)
case Split.UseSplit(_) => split // UseSplit is terminal; the referenced body determines fullness

object Split:
def default(term: Term): Split = Split.Else(term)
Expand Down Expand Up @@ -186,6 +221,11 @@ object Split:
case Split.Else(t) =>
(if isFirst && !isTopLevel then "" else "else") #: term(t)
case Split.End => Nil
case Split.LetSplit(sym, tail) =>
val bodyLines = split(sym.body, true, true)
(s"let-split ${sym.nme} =" #: bodyLines) ::: split(tail, false, isTopLevel)
case Split.UseSplit(sym) =>
(0, s"$$${sym.nme}") :: Nil
if s.duplicated then lines.map:
case (n, line) if !line.endsWith("// duplicated") => (n, s"$line // duplicated")
case other => other
Expand Down
10 changes: 10 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ class LabelSymbol(val trm: Opt[Term], name: Str = "lbl")(using State) extends Lo
def toLoc = trm.flatMap(_.toLoc)
override def prefix: Str = "label:"

/** Symbol representing a named split (join point) introduced during normalization.
* The `body` field holds the shared split that this symbol references.
* The `label` field is set during lowering to the corresponding LabelSymbol. */
class SplitSymbol(var body: Split, name: Str = "split")(using State) extends LocalSymbol:
var label: Opt[LabelSymbol] = N
def nme = name
def subst(using s: SymbolSubst): SplitSymbol = this // SplitSymbols are not substituted
def toLoc = body.toLoc
override def prefix: Str = "split:"

abstract class BlockLocalSymbol(name: Str)(using State) extends FlowSymbol(name):
self: LocalSymbol => // * using `with LocalSymbol` in the `extends` clause makes Scala think there's a bad override
var decl: Opt[Declaration] = N
Expand Down
Loading