Skip to content

Commit d08014c

Browse files
LPTKCopilotCAG2Mark
authored
Fix nondeterministic private field ordering in Lifter (#459)
The private fields generated by the lifter had nondeterministic ordering because they were built by iterating over Map values derived from Set operations. Scala's Set and Map do not guarantee stable iteration order. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: CAG2Mark <git@markng.com>
1 parent bb2388a commit d08014c

14 files changed

Lines changed: 71 additions & 97 deletions

File tree

hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -738,9 +738,9 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
738738
/** Maps definition symbols to the path representing that definition. */
739739
protected val passedDefnsMap: Map[DefinitionSymbol[?], DefnRef]
740740

741-
protected lazy val capturesOrdered: List[ScopedInfo]
741+
protected lazy val capturesOrdered: List[ScopedInfo] = reqCaptures.toList.sorted
742742
protected final lazy val passedSymsOrdered: List[Local] = reqPassedSymbols.toList.sortBy(_.uid)
743-
protected final lazy val passedDefnsOrdered: List[DefinitionSymbol[?]] = reqDefns.toList.sortBy(_.uid)
743+
protected final lazy val reqDefnsOrdered: List[DefinitionSymbol[?]] = reqDefns.toList.sortBy(_.uid)
744744

745745
override lazy val capturePaths: Map[ScopedInfo, Path] =
746746
if thisCapturedLocals.isEmpty then capSymsMap
@@ -776,7 +776,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
776776
defnPathsFromThisObj ++ fromParents
777777

778778
final def formatArgs: List[Arg] =
779-
val defnsArgs = passedDefnsOrdered.map(d => ctx.defnsMap(d).asArg)
779+
val defnsArgs = reqDefnsOrdered.map(d => ctx.defnsMap(d).asArg)
780780
val captureArgs = capturesOrdered.map(c => ctx.capturesMap(c).asArg)
781781
val localArgs = passedSymsOrdered.map(l => ctx.symbolsMap(l).asArg)
782782
defnsArgs ::: captureArgs ::: localArgs
@@ -789,7 +789,8 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
789789
sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]:
790790
lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap"))
791791
override lazy val capturePath = captureSym.asPath
792-
protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = node.liftedObjSyms.map: s =>
792+
protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid)
793+
protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = liftedObjsOrdered.map: s =>
793794
s -> VarSymbol(Tree.Ident(s.nme + "$"))
794795
.toMap
795796
override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map:
@@ -803,11 +804,14 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
803804
sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]:
804805
lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap"))
805806
override lazy val capturePath = captureSym.asPath
806-
protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = node.liftedObjSyms.map: s =>
807+
protected val liftedObjsOrdered: List[InnerSymbol] = node.liftedObjSyms.toList.sortBy(_.uid)
808+
protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = liftedObjsOrdered.map: s =>
807809
s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$"))
808810
.toMap
809811
override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map:
810812
case k -> v => k -> v.asLocalPath
813+
protected def appendCaptureField(privFields: List[TermSymbol]) =
814+
if hasCapture then captureSym :: privFields else privFields
811815
protected def rewriteMethods(node: ScopeNode, methods: List[FunDefn])(using ctx: LifterCtxNew) =
812816
val mtds = node.children
813817
.map: c =>
@@ -867,19 +871,19 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
867871

868872
private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"))
869873
override lazy val capturePath: Path = captureSym.asPath
870-
874+
871875
override def rewriteImpl: LifterResult[ClsLikeDefn] =
872876
val rewriterCtor = new BlockRewriter
873877
val rewriterPreCtor = new BlockRewriter
874878
val rewrittenCtor = rewriterCtor.rewrite(obj.cls.ctor)
875879
val rewrittenPrector = rewriterPreCtor.rewrite(obj.cls.preCtor)
876880
val ctorWithCap = addExtraSyms(rewrittenCtor, captureSym, Nil, false)
877-
881+
878882
val LifterResult(newMtds, extras) = rewriteMethods(node, obj.cls.methods)
879883
val newCls = obj.cls.copy(
880884
ctor = ctorWithCap,
881885
preCtor = rewrittenPrector,
882-
privateFields = captureSym :: liftedObjsSyms.values.toList ::: obj.cls.privateFields,
886+
privateFields = appendCaptureField(liftedObjsOrdered.map(liftedObjsSyms) ::: obj.cls.privateFields),
883887
methods = newMtds,
884888
)(obj.cls.configOverride)
885889
LifterResult(newCls, rewriterCtor.extraDefns.toList ::: rewriterPreCtor.extraDefns.toList ::: extras)
@@ -898,32 +902,30 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
898902
val LifterResult(newMtds, extras) = rewriteMethods(node, obj.clsBody.methods)
899903
val newComp = obj.clsBody.copy(
900904
ctor = ctorWithCap,
901-
privateFields = captureSym :: liftedObjsSyms.values.toList ::: obj.clsBody.privateFields,
905+
privateFields = appendCaptureField(liftedObjsOrdered.map(liftedObjsSyms) ::: obj.clsBody.privateFields),
902906
methods = newMtds
903907
)
904908
LifterResult(newComp, rewriterCtor.extraDefns.toList ::: extras)
905909

906910
class LiftedFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends LiftedScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]:
907-
private val passedSymsMap_ : Map[Local, VarSymbol] = passedSyms.map: s =>
911+
private val passedSymsMap_ : Map[Local, VarSymbol] = passedSymsOrdered.map: s =>
908912
s -> VarSymbol(Tree.Ident(s.nme))
909913
.toMap
910-
private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = reqCaptures.map: i =>
914+
private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = capturesOrdered.map: i =>
911915
val nme = data.getNode(i).obj.nme
912916
i -> VarSymbol(Tree.Ident(nme + "$cap"))
913917
.toMap
914-
private val defnSymsMap_ : Map[DefinitionSymbol[?], VarSymbol] = reqDefns.map: i =>
918+
private val defnSymsMap_ : Map[DefinitionSymbol[?], VarSymbol] = reqDefnsOrdered.sortBy(_.uid).map: i =>
915919
val nme = data.getNode(i).obj.nme
916920
i -> VarSymbol(Tree.Ident(nme + "$"))
917921
.toMap
918922

919-
override lazy val capturesOrdered: List[ScopedInfo] = reqCaptures.toList.sortBy(c => capSymsMap_(c).uid)
920-
921923
override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap
922924
override protected val capSymsMap = capSymsMap_.view.mapValues(_.asPath).toMap
923925
override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.asDefnRef).toMap
924926

925927
val auxParams: List[Param] =
926-
(passedDefnsOrdered.map(defnSymsMap_) ::: capturesOrdered.map(capSymsMap_) ::: passedSymsOrdered.map(passedSymsMap_))
928+
(reqDefnsOrdered.map(defnSymsMap_) ::: capturesOrdered.map(capSymsMap_) ::: passedSymsOrdered.map(passedSymsMap_))
927929
.map: s =>
928930
val decl = Param(FldFlags.empty.copy(isVal = false), s, N, Modulefulness.none)
929931
s.decl = S(decl)
@@ -1017,41 +1019,41 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
10171019
private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap"))
10181020
override lazy val capturePath: Path = captureSym.asPath
10191021

1020-
private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSyms.map: s =>
1022+
private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSymsOrdered.map: s =>
10211023
s ->
10221024
(
10231025
VarSymbol(Tree.Ident(s.nme)),
10241026
TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(s.nme))
10251027
)
10261028
.toMap
1027-
private val capSymsMap_ : Map[ScopedInfo, (vs: VarSymbol, ts: TermSymbol)] = reqCaptures.map: i =>
1029+
private val capSymsMap_ : Map[ScopedInfo, (vs: VarSymbol, ts: TermSymbol)] = capturesOrdered.map: i =>
10281030
val nme = data.getNode(i).obj.nme + "$cap"
10291031
i ->
10301032
(
10311033
VarSymbol(Tree.Ident(nme)),
10321034
TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(nme))
10331035
)
10341036
.toMap
1035-
private val defnSymsMap_ : Map[DefinitionSymbol[?], (vs: VarSymbol, ts: TermSymbol)] = reqDefns.map: i =>
1037+
private val defnSymsMap_ : Map[DefinitionSymbol[?], (vs: VarSymbol, ts: TermSymbol)] = reqDefnsOrdered.map: i =>
10361038
i ->
10371039
(
10381040
VarSymbol(Tree.Ident(i.nme + "$")),
10391041
TermSymbol(syntax.LetBind, S(obj.cls.isym), Tree.Ident(i.nme + "$"))
10401042
)
10411043
.toMap
10421044

1043-
private val extraPrivSyms =
1044-
liftedObjsSyms.values ++ passedSymsMap_.values.map(_.ts)
1045-
++ capSymsMap_.values.map(_.ts) ++ defnSymsMap_.values.map(_.ts)
1046-
1047-
override lazy val capturesOrdered: List[ScopedInfo] = reqCaptures.toList.sortBy(c => capSymsMap_(c).vs.uid)
1045+
private lazy val extraPrivSyms: List[TermSymbol] =
1046+
liftedObjsOrdered.map(liftedObjsSyms)
1047+
::: reqDefnsOrdered.map(defnSymsMap_(_).ts)
1048+
::: capturesOrdered.map(capSymsMap_(_).ts)
1049+
::: passedSymsOrdered.map(passedSymsMap_(_).ts)
10481050

10491051
override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.ts.asLocalPath).toMap
10501052
override protected val capSymsMap = capSymsMap_.view.mapValues(_.ts.asPath).toMap
10511053
override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.ts.asDefnRef).toMap
10521054

10531055
val auxParams: List[Param] =
1054-
(passedDefnsOrdered.map(x => defnSymsMap_(x).vs)
1056+
(reqDefnsOrdered.map(x => defnSymsMap_(x).vs)
10551057
::: capturesOrdered.map(x => capSymsMap_(x).vs)
10561058
::: passedSymsOrdered.map(x => passedSymsMap_(x).vs))
10571059
.map(Param.simple(_))
@@ -1158,7 +1160,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
11581160
case (sym, acc) =>
11591161
val (vs, ts) = capSymsMap_(sym)
11601162
Assign(ts, vs.asPath, acc)
1161-
val ctorWithDefns = passedDefnsOrdered.foldRight(ctorWithCaps):
1163+
val ctorWithDefns = reqDefnsOrdered.foldRight(ctorWithCaps):
11621164
case (sym, acc) =>
11631165
val (vs, ts) = defnSymsMap_(sym)
11641166
Assign(ts, vs.asPath, acc)
@@ -1168,12 +1170,13 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
11681170
else PlainParamList(auxParams) :: cls.auxParams
11691171

11701172
val LifterResult(newMtds, extras) = rewriteMethods(node, obj.cls.methods)
1173+
11711174
val newCls = obj.cls.copy(
11721175
owner = N,
11731176
k = syntax.Cls, // turn objects into classes
11741177
ctor = ctorWithDefns,
11751178
preCtor = rewrittenPrector,
1176-
privateFields = captureSym :: extraPrivSyms.toList ::: obj.cls.privateFields,
1179+
privateFields = appendCaptureField(extraPrivSyms ::: obj.cls.privateFields),
11771180
methods = newMtds,
11781181
auxParams = newAuxList
11791182
)(obj.cls.configOverride)

hkmc2/shared/src/main/scala/hkmc2/codegen/ScopeData.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ object ScopeData:
4646
case _ => super.applyDefn(defn)
4747
type ScopedInfo = DefinitionSymbol[?] | LabelSymbol | ScopeUID | Unit
4848

49+
given Ordering[ScopedInfo] with
50+
def compare(x: ScopedInfo, y: ScopedInfo): Int = (x, y) match
51+
case (a: Symbol, b: Symbol) => Ordering[Uid[Symbol]].compare(a.uid, b.uid)
52+
case (_: Symbol, _) => -1
53+
case (_, _: Symbol) => 1
54+
case ((), ()) => 0
55+
case ((), _) => -1
56+
case (_, ()) => 1
57+
case (a: Int, b: Int) => a compare b // ScopeUID is Int inside ScopeData
58+
4959
// ScopeData requires the set of ignored scopes to compute certain things, but
5060
// the lifter requires the scope tree to generate the metadata. To solve this,
5161
// we generate the scope tree then populate the metadata later.

hkmc2/shared/src/test/mlscript/basics/ValMemberSymbols.mls

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ fun foo(x, y) =
128128
//│ end
129129
//│ };
130130
//│ define A⁶ as class A⁷(u) {
131-
//│ private val A$cap⁰;
132131
//│ private val y⁰;
133132
//│ val u⁰;
134133
//│ val v⁰;

hkmc2/shared/src/test/mlscript/codegen/FirstClassFunctionTransform.mls

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ x => x
2424
//│ return x
2525
//│ };
2626
//│ define Function$⁰ as class Function$¹ {
27-
//│ private val Function$$cap⁰;
2827
//│ method call⁰ = fun call¹(x) { return lambda¹(x) }
2928
//│ };
3029
//│ set tmp = new Function$¹();
@@ -37,7 +36,7 @@ x => x
3736
:noInline
3837
x => x
3938
//│ ╔══[INTERNAL ERROR] [BlockChecker] Invalid IR: symbol x⁰ is bound more than once
40-
//│ ║ l.38: x => x
39+
//│ ║ l.37: x => x
4140
//│ ╙── ^
4241
//│ = class anonymous
4342

@@ -75,7 +74,6 @@ bar(foo)
7574
//│ return f.call﹖(0)
7675
//│ };
7776
//│ define Function$² as class Function$³ {
78-
//│ private val Function$$cap¹;
7977
//│ method call² = fun call³(x) { return foo¹(x) }
8078
//│ };
8179
//│ set tmp = new Function$³();
@@ -98,12 +96,12 @@ let f = foo in bar(f)
9896
bar([foo].0)
9997
bar([foo, x => x].1)
10098
//│ ╔══[COMPILATION ERROR] Cannot determine if 0 is a function.
101-
//│ ║ l.98: bar([foo].0)
99+
//│ ║ l.96: bar([foo].0)
102100
//│ ╙── ^^
103101
//│ ═══[COMPILATION ERROR] Cannot determine if 0$__checkNotMethod is a function.
104102
//│ ═══[COMPILATION ERROR] Cannot determine if Error is a function.
105103
//│ ╔══[COMPILATION ERROR] Cannot determine if 1 is a function.
106-
//│ ║ l.99: bar([foo, x => x].1)
104+
//│ ║ l.97: bar([foo, x => x].1)
107105
//│ ╙── ^^
108106
//│ ═══[COMPILATION ERROR] Cannot determine if 1$__checkNotMethod is a function.
109107
//│ ═══[COMPILATION ERROR] Cannot determine if Error is a function.
@@ -120,7 +118,7 @@ bar([foo, x => x].(i))
120118
:ge
121119
[foo, x => x].(i)(0)
122120
//│ ╔══[COMPILATION ERROR] Cannot determine if the dynamic selection is a function object.
123-
//│ ║ l.121: [foo, x => x].(i)(0)
121+
//│ ║ l.119: [foo, x => x].(i)(0)
124122
//│ ╙── ^
125123

126124

@@ -134,10 +132,10 @@ let foo = Foo()
134132
foo.("f")(0)
135133
foo.("h")(1)
136134
//│ ╔══[COMPILATION ERROR] Cannot determine if the dynamic selection is a function object.
137-
//│ ║ l.134: foo.("f")(0)
135+
//│ ║ l.132: foo.("f")(0)
138136
//│ ╙── ^^^^^^^^
139137
//│ ╔══[COMPILATION ERROR] Cannot determine if the dynamic selection is a function object.
140-
//│ ║ l.135: foo.("h")(1)
138+
//│ ║ l.133: foo.("h")(1)
141139
//│ ╙── ^^^^^^^^
142140
//│ = 0
143141
//│ foo = Foo()
@@ -157,7 +155,6 @@ Foo(1).foo(x => x - 1)
157155
//│ return -⁰(x, 1)
158156
//│ };
159157
//│ define Function$⁴ as class Function$⁵ {
160-
//│ private val Function$$cap²;
161158
//│ method call⁴ = fun call⁵(x) { return lambda⁵(x) }
162159
//│ };
163160
//│ set tmp1 = new Function$⁵();
@@ -176,7 +173,6 @@ foo()(1)
176173
//│ return x
177174
//│ };
178175
//│ define Function$⁶ as class Function$⁷ {
179-
//│ private val Function$$cap³;
180176
//│ method call⁶ = fun call⁷(x) {
181177
//│ return lambda⁷(x)
182178
//│ }
@@ -248,7 +244,6 @@ module Foo with
248244
//│ return +⁰(tmp, z)
249245
//│ };
250246
//│ define Function$⁸ as class Function$⁹ {
251-
//│ private val Function$$cap⁴;
252247
//│ private val y⁰;
253248
//│ private val z⁰;
254249
//│ constructor(y, z) {
@@ -261,14 +256,10 @@ module Foo with
261256
//│ return lambda⁹(y⁰, z⁰, x)
262257
//│ }
263258
//│ };
264-
//│ define Foo¹ as class Foo² {
265-
//│ private val Foo$cap⁰;
266-
//│ }
259+
//│ define Foo¹ as class Foo²
267260
//│ module Foo³ {
268-
//│ private val Foo_mod$cap⁰;
269261
//│ constructor {
270262
//│ define Bar⁰ as class Bar²(y) {
271-
//│ private val Bar$cap⁰;
272263
//│ private val y¹;
273264
//│ constructor Bar¹ {
274265
//│ set y¹ = y;
@@ -304,7 +295,6 @@ y => x + y
304295
//│ return +⁰(x¹, y)
305296
//│ };
306297
//│ define Function$¹⁰ as class Function$¹¹ {
307-
//│ private val Function$$cap⁵;
308298
//│ method call¹⁰ = fun call¹¹(y) { return lambda¹¹(y) }
309299
//│ };
310300
//│ set tmp = new Function$¹¹();
@@ -399,7 +389,7 @@ foo.Foo#x
399389
:ge
400390
foo.f(0)
401391
//│ ╔══[COMPILATION ERROR] Cannot determine if f is a function object.
402-
//│ ║ l.400: foo.f(0)
392+
//│ ║ l.390: foo.f(0)
403393
//│ ╙── ^^^^^
404394
//│ = 1
405395

@@ -417,7 +407,7 @@ foo.Foo#f(0)
417407
:ge
418408
foo.x(0)
419409
//│ ╔══[COMPILATION ERROR] Cannot determine if x is a function object.
420-
//│ ║ l.418: foo.x(0)
410+
//│ ║ l.408: foo.x(0)
421411
//│ ╙── ^^^^^
422412

423413

@@ -505,7 +495,6 @@ fun foo(x)(y) = x + y
505495
//│ return +⁰(x, y)
506496
//│ };
507497
//│ define Function$¹² as class Function$¹³ {
508-
//│ private val Function$$cap⁶;
509498
//│ private val x²;
510499
//│ constructor(x) {
511500
//│ super⁰();
@@ -554,7 +543,6 @@ foo(tuple)
554543
//│ ———————————————| Lowered IR |———————————————————————————————————————————————————————————————————————
555544
//│ let Function$¹⁴, tmp;
556545
//│ define Function$¹⁴ as class Function$¹⁵ {
557-
//│ private val Function$$cap⁷;
558546
//│ method call¹⁴ = fun call¹⁵(...xs) {
559547
//│ return Predef⁰.tuple⁰()
560548
//│ }

hkmc2/shared/src/test/mlscript/dead-param-elim/class-in-fun.mls

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ f(1, 2).get()
1818
//│ ——————————————| Optimized IR |——————————————————————————————————————————————————————————————————————
1919
//│ let C⁰, f⁰, tmp, C$⁰;
2020
//│ define C⁰ as class C¹ {
21-
//│ private val C$cap⁰;
2221
//│ private val used⁰;
2322
//│ constructor(used) {
2423
//│ set used⁰ = used;

0 commit comments

Comments
 (0)