Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
subTerm_nonTail(fld): f =>
subTerm_nonTail(rhs): r =>
AssignDynField(p, f, ai, r, k(unit))
case sel @ SelProj(prefix, _, proj) =>
subTerm(prefix): p =>
subTerm_nonTail(rhs): r =>
AssignField(p, proj, r, k(unit))(sel.sym)
case _ => fail:
ErrorReport(
msg"Unexpected left-hand side in assignment (${lhs.describe})" -> lhs.toLoc :: Nil, S(lhs),
Expand Down
141 changes: 115 additions & 26 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
&& ((defn.k is syntax.Cls) || (defn.k is syntax.Obj))
&& defn.auxParams.isEmpty
&& (!(defn.k is syntax.Obj) || defn.parentPath.isEmpty)
&& defn.methods.isEmpty
&& (!(defn.k is syntax.Obj) || defn.methods.isEmpty)
&& defn.companion.isEmpty

/** Returns singleton metadata when `sym` resolves to a registered singleton object. */
Expand Down Expand Up @@ -276,7 +276,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
ctx.getRuntimeClassTags(childSym).getOrElse(lastWords("unreachable"))
ctx.registerRuntimeClassTags(defn.sym, LinkedHashSet(ownTag) ++ childTags)

/** Declares the shared Wasm function type used by a class ctor/init placeholder. */
/** Declares the shared Wasm function type used by a class-associated function placeholder. */
private def declareClassFuncType(
defn: ClsLikeDefn,
suffix: Str,
Expand Down Expand Up @@ -306,7 +306,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
private def initFuncSym(sym: BlockMemberSymbol): BlockMemberSymbol =
initFuncSyms.getOrElseUpdate(sym, BlockMemberSymbol(s"${sym.nme}_init", Nil, nameIsMeaningful = false))

/** Registers a placeholder class ctor/init function so later lowering can overwrite it. */
/** Registers a placeholder class-associated function so later lowering can overwrite it. */
private def predeclareClassFunc(
defn: ClsLikeDefn,
suffix: Str,
Expand Down Expand Up @@ -354,6 +354,28 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
s"${sym.nme}_ctor"
predeclareClassFunc(defn, "ctor", ctorParams, S(defn.sym), N, ctorId)

/** Declares one top-level class method. */
private def predeclareMethod(methodDefn: FunDefn, ownerCls: ClsLikeDefn)(using Ctx, Raise, Scope): Unit =
val methodParams = (ownerCls.isym -> "this") +:
methodDefn.params.headOption.fold(Nil): ps =>
ps.params.map: p =>
p.sym -> p.sym.nme
val methodId = ownerCls.sym
.optionIf: sym =>
!(ownerCls.k is syntax.Obj) && sym.nameIsMeaningful
.map: sym =>
s"${sym.nme}_${methodDefn.sym.nme}"
predeclareClassFunc(ownerCls, methodDefn.sym.nme, methodParams, S(methodDefn.sym), methodId, N)

/** Declares placeholders for all methods on one top-level class. */
private def predeclareClassMethods(defn: ClsLikeDefn)(using Ctx, Raise, Scope): Unit =
defn.methods.foreach:
case methodDefn @ FunDefn(_, _, _, Nil, _) =>
predeclareMethod(methodDefn, defn)
case methodDefn @ FunDefn(_, _, _, _ :: Nil, _) =>
predeclareMethod(methodDefn, defn)
case _ => ()

/** Gets (and caches) the exception tag used for MLX `throw`. */
private def exnTagIdx(using Ctx, Raise, Scope): TagIdx =
val symNme = scope.allocateName(TempSymbol(N, "mlx_exn"))
Expand Down Expand Up @@ -775,6 +797,11 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
))
end fieldSelect

/** Resolves `sym` to a predeclared class method symbol, if any. */
private def predeclaredClassMethodSym(sym: DefinitionSymbol[?])(using Ctx): Opt[BlockMemberSymbol] =
sym.asBlkMember.filter: methodSym =>
methodSym.asTrm.exists(_.owner.exists(_.asCls.isDefined)) && ctx.getFunc(methodSym).nonEmpty

def result(r: codegen.Result)(using Ctx, Raise, Scope): Expr = r match
case Value.This(sym) =>
// TODO(Derppening): Add type tracking and refinement for locals, remove the `ref.cast`
Expand Down Expand Up @@ -825,6 +852,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
else
errExpr(Ls(msg"Cannot call non-binary builtin symbol '${l.nme}'" -> r.toLoc))

case Call(sel @ Select(qual, _), args) if sel.symbol.flatMap(predeclaredClassMethodSym).nonEmpty =>
val methodSym = sel.symbol.flatMap(predeclaredClassMethodSym).get
call(
funcidx = ctx.getFunc_!(methodSym),
operands = result(qual) +: args.map(argument),
returnTypes = Seq(Result(RefType.anyref)),
)

case c @ Call(fun, args) =>
wasmIntrinsicName(fun) match
case S(intrName) =>
Expand Down Expand Up @@ -902,6 +937,24 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
extraInfo = S(sel),
)

case S(selSym) if predeclaredClassMethodSym(selSym).nonEmpty =>
val methodSym = predeclaredClassMethodSym(selSym).get
methodSym.asTrm.flatMap(_.defn) match
case S(defn: TermDefinition) if defn.params.isEmpty =>
call(
funcidx = ctx.getFunc_!(methodSym),
operands = Seq(result(qual)),
returnTypes = Seq(Result(RefType.anyref)),
)
case _ =>
errExpr(
Ls(
msg"`${methodSym.toString}` is neither a field access nor a callable method" ->
sel.toLoc,
),
extraInfo = S(sel),
)

case S(selSym: TermSymbol) =>
val qualRes = result(qual)
val selOwner = selSym.owner getOrElse:
Expand Down Expand Up @@ -1186,24 +1239,26 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
val lhsExpr = result(lhs)
val rhsExpr = result(rhs)
val assignInstr = assign.symbol match
case S(selSym: TermSymbol) =>
val selOwner = selSym.owner getOrElse
lastWords(s"Expected resolved AssignField(...) expression `$selSym` to have an owner")
val selCls = selOwner.asBlkMember getOrElse
lastWords(
s"Expected resolved class for AssignField(...) expression to be a BlockMemberSymbol, but got $selOwner (${
selOwner.getClass.getName
})",
)
val fieldidx = fieldSelect(selCls, selSym)
val objRef = ref.cast(lhsExpr, RefType(ctx.getType_!(selCls), nullable = false))
struct.set(fieldidx, objRef, rhsExpr)
case S(otherSym) =>
lastWords(
s"Expected resolved AssignField(...) expression to be a TermSymbol, but got $otherSym (${
otherSym.getClass.getName
})",
)
case S(selSym) =>
selSym.asTrm match
case S(fieldSym) =>
val selOwner = fieldSym.owner getOrElse
lastWords(s"Expected resolved AssignField(...) expression `$fieldSym` to have an owner")
val selCls = selOwner.asBlkMember getOrElse
lastWords(
s"Expected resolved class for AssignField(...) expression to be a BlockMemberSymbol, but got $selOwner (${
selOwner.getClass.getName
})",
)
val fieldidx = fieldSelect(selCls, fieldSym)
val objRef = ref.cast(lhsExpr, RefType(ctx.getType_!(selCls), nullable = false))
struct.set(fieldidx, objRef, rhsExpr)
case N =>
lastWords(
s"Expected resolved AssignField(...) expression to be a TermSymbol, but got $selSym (${
selSym.getClass.getName
})",
)
case N =>
errExpr(
Ls(
Expand Down Expand Up @@ -1300,7 +1355,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
val result = pss.foldRight(bod):
case (ps, block) =>
Return(Lambda(ps, block), false)
val (params, bodyWat, locals) = setupFunction(ps, result)
val (params, bodyWat, locals) = setupFunction(N, ps, result)
if sym.nameIsMeaningful then
val funcTy = ctx.addType(
sym = N,
Expand Down Expand Up @@ -1355,8 +1410,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
break(errUnimplExpr("auxParams.nonEmpty"))
if isSingletonObj && clsLikeDefn.parentPath.nonEmpty then
break(errUnimplExpr("parentPath.nonEmpty for object"))
if clsLikeDefn.methods.nonEmpty then
break(errUnimplExpr("methods.nonEmpty"))
if isSingletonObj && clsLikeDefn.methods.nonEmpty then
break(errUnimplExpr("methods.nonEmpty for object"))
if clsLikeDefn.companion.isDefined then
break(errUnimplExpr("companion.isDefined"))

Expand Down Expand Up @@ -1438,6 +1493,35 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
`export` = predeclaredCtor.`export`,
),
)

def overwriteMethod(
sym: BlockMemberSymbol,
methodParamLocals: Seq[Local],
ps: ParamList,
bod: Block,
): Unit =
val (params, bodyWat, locals) = setupFunction(S(clsLikeDefn.isym -> "this"), ps, bod)
val predeclaredMethod = ctx.getFuncInfo_!(sym)
ctx.addFunc(
S(sym),
FuncInfo(
id = predeclaredMethod.id,
typeUse = predeclaredMethod.typeUse,
params = methodParamLocals.zip(params.map(_._2)),
nResults = bodyWat.resultTypes.length,
locals = locals,
body = bodyWat,
`export` = predeclaredMethod.`export`,
),
)

clsLikeDefn.methods.foreach:
case methodDefn @ FunDefn(_, sym, _, Nil, bod) =>
overwriteMethod(sym, Seq(clsLikeDefn.isym), PlainParamList(Nil), bod)
case methodDefn @ FunDefn(_, sym, _, _ :: _ :: _, _) =>
break(errUnimplExpr("multi-parameter-list method"))
case methodDefn @ FunDefn(_, sym, _, ps :: Nil, bod) =>
overwriteMethod(sym, clsLikeDefn.isym +: ps.params.map(_.sym), ps, bod)
if isSingletonObj then
registerSingletonInit(clsLikeDefn, typeref)

Expand Down Expand Up @@ -1813,6 +1897,7 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
predeclareClassTags(ordered)
ordered.foreach(predeclareClassInit)
ordered.foreach(predeclareClassConstructor)
ordered.foreach(predeclareClassMethods)

// Compile the entry function under a dedicated local scope so that any temp locals introduced
// during codegen (e.g., via `local.tee`) are declared in the entry function.
Expand Down Expand Up @@ -1927,23 +2012,27 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
nonNestedScoped(t)(returningTerm)

def setupFunction(
thisParam: Opt[Local -> Str],
params: ParamList,
body: Block,
)(using Ctx, Raise, Scope): (Seq[WasmParam -> Str], Expr, Seq[(Local, Str)]) =
// Add a frame for `ctx.locals`
ctx.pushLocal()

val result = scope.nest givenIn:
val wasmThisParam = thisParam.toSeq.map: (sym, _) =>
val (_, thisVarName) = bindCtorThis(sym)
WasmParam(thisVarName, RefType.anyref) -> thisVarName
val wasmParams = params.params.map: p =>
val paramNme = scope.allocateName(p.sym)
val param = WasmParam(paramNme, RefType.anyref)
ctx.addLocal(p.sym)
param -> paramNme
val (wasmBody, locals) = block(body)
val paramSyms: Set[Local] = params.params.map(p => (p.sym: Local)).toSet
val paramSyms: Set[Local] = thisParam.iterator.map(_._1).toSet ++ params.params.map(p => (p.sym: Local))
val extraLocals = getExtraLocals.filterNot((locals.toSet ++ paramSyms).contains)
val localsWithNames = (locals ++ extraLocals).map(l => l -> scope.allocateOrGetName(l))
(wasmParams.toSeq, wasmBody, localsWithNames)
(wasmThisParam ++ wasmParams, wasmBody, localsWithNames)

// Restore `ctx.locals`
ctx.popLocal()
Expand Down
Loading
Loading