Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ case class Config(
sanityChecks: Opt[SanityChecks],
effectHandlers: Opt[EffectHandlers],
liftDefns: Opt[LiftDefns],
patMatConsequentSharingThreshold: Opt[Int],
stageCode: Bool,
target: CompilationTarget,
rewriteWhileLoops: Bool,
Expand Down Expand Up @@ -61,7 +60,6 @@ object Config:
// sanityChecks = S(SanityChecks(light = true)),
effectHandlers = N,
liftDefns = N,
patMatConsequentSharingThreshold = default.patMatConsequentSharingThreshold, // minimum: 1
target = CompilationTarget.JS,
rewriteWhileLoops = false,
stageCode = false,
Expand All @@ -74,8 +72,7 @@ object Config:
noFreeze = false,
noModuleCheck = false,
)
object default:
val patMatConsequentSharingThreshold = S(15)
object default

case class SanityChecks(light: Bool)

Expand Down Expand Up @@ -254,10 +251,6 @@ object ConfigParser:
parseOpt(value)(_ => S(Config.SanityChecks(light = true))) match
case S(v) => _.copy(sanityChecks = v)
case N => identity
case "patMatConsequentSharingThreshold" =>
parseInt(value) match
case S(v) => _.copy(patMatConsequentSharingThreshold = S(v))
case N => identity
case _ =>
raise(ErrorReport(
msg"Unknown config field '${name}'" -> value.toLoc :: Nil,
Expand Down
8 changes: 8 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
val l = loweringCtx.registerTempSymbol(N)
Assign(l, r, setupTerm("Else", Value.Ref(l) :: Nil)(k))
case Split.End => setupTerm("End", Nil)(k)
case Split.LetSplit(sym, tail) =>
quoteSplit(sym.body): r1 =>
val l1 = loweringCtx.registerTempSymbol(N)
blockBuilder.assign(l1, r1)
.chain(b => quoteSplit(tail)(r2 => Assign(loweringCtx.registerTempSymbol(N), r2, b)))
.rest(setupTerm("LetSplit", Value.Ref(l1) :: Nil)(k))
case Split.UseSplit(sym) =>
setupTerm("UseSplit", Nil)(k)

lazy val setupFilename: Path =
val state = summon[State]
Expand Down
4 changes: 4 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/invalml/InvalML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
val (ty, res) = typeCheck(alts)
(ty, res, Nil, true)
case Split.End => (Bot, Bot, Nil, false)
case Split.LetSplit(sym, tail) => typeADTMatch(tail, sign)
case Split.UseSplit(sym) => typeADTMatch(sym.body, sign)

private def typeSplit
(split: Split, sign: Opt[GeneralType])(using ctx: InvalCtx)(using CCtx, Scope)
Expand Down Expand Up @@ -442,6 +444,8 @@ class InvalTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
case S(sign) => ascribe(alts, sign)
case _ => typeCheck(alts)
case Split.End => (Bot, Bot)
case Split.LetSplit(sym, tail) => typeSplit(tail, sign)
case Split.UseSplit(sym) => typeSplit(sym.body, sign)

private def typeAllSplits
(split: Split, sign: Opt[GeneralType])(using ctx: InvalCtx)(using CCtx, Scope)
Expand Down
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Resolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,8 @@ object ModuleChecker:
case Split.Let(_, _, tail) => go(tail)
case Split.Else(term) => term :: Nil
case Split.End => Nil
case Split.LetSplit(sym, tail) => go(sym.body) ::: go(tail)
case Split.UseSplit(_) => Nil
go(s)

/** Checks if a symbol is of a type parameter. */
Expand Down
31 changes: 30 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Split.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ enum Split extends AutoLocated with ProductWithTail:
case Let(sym: BlockLocalSymbol, term: Term, tail: Split)
case Else(default: Term)
case End
/** Declares a named split (join point). The symbol's `body` holds the shared
* split; `tail` is the continuation where the symbol is in scope. */
case LetSplit(sym: SplitSymbol, tail: Split)
/** References a previously declared split (join point). */
case UseSplit(sym: SplitSymbol)

inline def ~:(head: Branch): Split = Split.Cons(head, this)

Expand All @@ -31,6 +36,8 @@ enum Split extends AutoLocated with ProductWithTail:
case Let(sym, term, tail) => Let(sym, term.mkClone, tail.mkClone)
case Else(default) => Else(default.mkClone)
case End => End
case LetSplit(sym, tail) => LetSplit(sym, tail.mkClone)
case UseSplit(sym) => UseSplit(sym)

/** Used to indicate whether the `Split` was duplicated during desugaring or
* normalization. */
Expand All @@ -47,31 +54,41 @@ enum Split extends AutoLocated with ProductWithTail:
case Cons(head, tail) => Cons(head, tail.duplicate)
case Let(name, term, tail) => Let(name, term, tail.duplicate)
case Else(default) => Else(default)
case End => End).setDuplicated
case End => End
case LetSplit(sym, tail) => LetSplit(sym, tail.duplicate)
case UseSplit(sym) => UseSplit(sym)).setDuplicated

lazy val isFull: Bool = this match
case Split.Cons(_, tail) => tail.isFull
case Split.Let(_, _, tail) => tail.isFull
case Split.Else(_) => true
case Split.End => false
case Split.LetSplit(_, tail) => tail.isFull
case Split.UseSplit(sym) => sym.body.isFull

lazy val isEmpty: Bool = this match
case Split.Let(_, _, tail) => tail.isEmpty
case Split.Else(_) | Split.Cons(_, _) => false
case Split.End => true
case Split.LetSplit(_, tail) => tail.isEmpty
case Split.UseSplit(sym) => sym.body.isEmpty

final override def children: Vector[Located] = this match
case Split.Cons(head, tail) => Vector.double(head, tail)
case Split.Let(name, term, tail) => Vector.triple(name, term, tail)
case Split.Else(default) => Vector.single(default)
case Split.End => Vector.empty
case Split.LetSplit(sym, tail) => Vector.double(sym, tail)
case Split.UseSplit(sym) => Vector.single(sym)

def subTerms: Vector[Term] = this match
case Split.Cons(Branch(scrutinee, pattern, continuation), tail) =>
scrutinee +: (pattern.subTerms ++ continuation.subTerms ++ tail.subTerms)
case Split.Let(_, term, tail) => term +: tail.subTerms
case Split.Else(term) => Vector.single(term)
case Split.End => Vector.empty
case Split.LetSplit(sym, tail) => sym.body.subTerms ++ tail.subTerms
case Split.UseSplit(_) => Vector.empty

/** Free variable names, accounting for let bindings. */
lazy val freeVars: Set[Str] = this match
Expand All @@ -82,12 +99,16 @@ enum Split extends AutoLocated with ProductWithTail:
term.freeVars ++ (tail.freeVars - sym.nme)
case Split.Else(term) => term.freeVars
case Split.End => Set.empty
case Split.LetSplit(sym, tail) => sym.body.freeVars ++ tail.freeVars
case Split.UseSplit(sym) => sym.body.freeVars

final def showDbg(using DebugPrinter): String = this match
case Split.Cons(head, tail) => s"${head.showDbg}; ${tail.showDbg}"
case Split.Let(name, term, tail) => s"let ${name} = ${term.showDbg}; ${tail.showDbg}"
case Split.Else(default) => s"else ${default.showDbg}"
case Split.End => ""
case Split.LetSplit(sym, tail) => s"let-split ${sym.nme} = { ${sym.body.showDbg} }; ${tail.showDbg}"
case Split.UseSplit(sym) => s"$$${sym.nme}"

final override def withLoc(loco: Option[Loc]): this.type =
super.withLoc:
Expand All @@ -96,6 +117,7 @@ enum Split extends AutoLocated with ProductWithTail:
// which causes the assertion of distinctness of origins to fail.
case Split.End => N
case _: Split.Else => N // FIXME: @Luyu pls clean up this mess
case _: Split.UseSplit => N
case _ => loco

var isFallback: Bool = false
Expand All @@ -112,6 +134,8 @@ extension (split: Split)
case Split.Let(name, term, tail) => Split.Let(name, term, tail ~~: fallback)
case Split.Else(_) => lastWords("impossible since split is not full")
case Split.End => fallback
case Split.LetSplit(sym, tail) => Split.LetSplit(sym, tail ~~: fallback)
case Split.UseSplit(_) => split // UseSplit is terminal; the referenced body determines fullness

object Split:
def default(term: Term): Split = Split.Else(term)
Expand Down Expand Up @@ -186,6 +210,11 @@ object Split:
case Split.Else(t) =>
(if isFirst && !isTopLevel then "" else "else") #: term(t)
case Split.End => Nil
case Split.LetSplit(sym, tail) =>
val bodyLines = split(sym.body, true, true)
(s"let-split ${sym.nme} =" #: bodyLines) ::: split(tail, false, isTopLevel)
case Split.UseSplit(sym) =>
(0, s"$$${sym.nme}") :: Nil
if s.duplicated then lines.map:
case (n, line) if !line.endsWith("// duplicated") => (n, s"$line // duplicated")
case other => other
Expand Down
10 changes: 10 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ class LabelSymbol(val trm: Opt[Term], name: Str = "lbl")(using State) extends Lo
def toLoc = trm.flatMap(_.toLoc)
override def prefix: Str = "label:"

/** Symbol representing a named split (join point) introduced during normalization.
* The `body` field holds the shared split that this symbol references.
* The `label` field is set during lowering to the corresponding LabelSymbol. */
class SplitSymbol(var body: Split, name: Str = "split")(using State) extends LocalSymbol:
var label: Opt[LabelSymbol] = N
def nme = name
def subst(using s: SymbolSubst): SplitSymbol = this // SplitSymbols are not substituted
def toLoc = body.toLoc
override def prefix: Str = "split:"

abstract class BlockLocalSymbol(name: Str)(using State) extends FlowSymbol(name):
self: LocalSymbol => // * using `with LocalSymbol` in the `extends` clause makes Scala think there's a bad override
var decl: Opt[Declaration] = N
Expand Down
Loading
Loading