Skip to content

Commit 9920fd1

Browse files
Stabilize param-eq baseline and pin experimental commit
1 parent d2aac63 commit 9920fd1

35 files changed

Lines changed: 3794 additions & 8057 deletions

.cargo/config.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22
[build]
33
# debug symbols https://pyo3.rs/main/debugging#common-setup
44
rustflags = ["-g"]
5+
6+
[net]
7+
git-fetch-with-cli = true

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ opentelemetry = "0.28"
1818
opentelemetry-otlp = { version = "0.28", features = ["http-proto", "reqwest-blocking-client", "trace"] }
1919
opentelemetry-stdout = { version = "0.28", features = ["trace"] }
2020
opentelemetry_sdk = "0.28"
21-
egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29", default-features = false }
21+
egglog = { git = "https://github.com/egraphs-good/egglog.git", default-features = false, rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29" }
2222
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29" }
2323
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29" }
2424
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29" }
2525
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", rev = "b5c211b9def133cad9540a11744e8a1e40bd2a29" }
2626

2727

28-
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental.git", rev = "fae7440e67497fb7ac56c889dde9eedec17636f9", default-features = false }
28+
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental.git", default-features = false, rev = "3f38efab7307b765bdb912b81e99736f27e00b1f" }
2929
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
3030
serde_json = "1"
3131
pyo3-log = "*"

python/egglog/bindings.pyi

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ __all__ = [
7878
"RustSpan",
7979
"Saturate",
8080
"Scan",
81-
"SchedulerHandle",
8281
"Schema",
8382
"Sequence",
8483
"SerializedEGraph",
@@ -136,22 +135,6 @@ class EGraph:
136135
def run_program(
137136
self, *commands: _Command, traceparent: str | None = None, tracestate: str | None = None
138137
) -> list[_CommandOutput]: ...
139-
def add_backoff_scheduler(
140-
self,
141-
match_limit: int,
142-
ban_length: int,
143-
*,
144-
egg_like: bool = False,
145-
haskell_backoff: bool = False,
146-
) -> SchedulerHandle: ...
147-
def run_ruleset_with_scheduler(
148-
self,
149-
ruleset: str,
150-
scheduler: SchedulerHandle,
151-
*,
152-
traceparent: str | None = None,
153-
tracestate: str | None = None,
154-
) -> RunReport: ...
155138
def serialize(
156139
self,
157140
root_eclasses: list[_Expr],
@@ -193,7 +176,6 @@ class Value:
193176
def __ge__(self, other: object) -> bool: ...
194177

195178
@final
196-
class SchedulerHandle: ...
197179

198180
@final
199181
class EggSmolError(Exception):

python/egglog/declarations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,6 @@ def visit(typed_expr: TypedExprDecl) -> None:
10311031
@dataclass(frozen=True)
10321032
class SaturateDecl:
10331033
schedule: ScheduleDecl
1034-
stop_when_no_updates: bool = False
10351034

10361035

10371036
@dataclass(frozen=True)
@@ -1066,7 +1065,8 @@ class BackOffDecl:
10661065
id: UUID
10671066
match_limit: int | None
10681067
ban_length: int | None
1069-
egg_like: bool = False
1068+
fresh_rematch: bool = False
1069+
persistent: bool = False
10701070

10711071

10721072
##

python/egglog/egraph.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tempfile
77
from collections.abc import Callable, Generator, Iterable
88
from contextvars import ContextVar, Token
9-
from dataclasses import InitVar, dataclass, field
9+
from dataclasses import InitVar, dataclass, field, replace
1010
from functools import partial
1111
from inspect import Parameter, currentframe, getmodule, signature
1212
from types import FrameType, FunctionType
@@ -976,31 +976,6 @@ def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
976976
assert isinstance(command_output, bindings.RunScheduleOutput)
977977
return command_output.report
978978

979-
def _add_backoff_scheduler(
980-
self,
981-
*,
982-
match_limit: int,
983-
ban_length: int,
984-
egg_like: bool,
985-
haskell_backoff: bool = False,
986-
) -> bindings.SchedulerHandle:
987-
return self._egraph.add_backoff_scheduler(
988-
match_limit,
989-
ban_length,
990-
egg_like=egg_like,
991-
haskell_backoff=haskell_backoff,
992-
)
993-
994-
def _run_ruleset_with_scheduler(
995-
self,
996-
ruleset: Ruleset | UnstableCombinedRuleset,
997-
scheduler: bindings.SchedulerHandle,
998-
) -> bindings.RunReport:
999-
self._add_decls(ruleset)
1000-
ruleset_ident = ruleset.__egg_ident__
1001-
self._state.ruleset_to_egg(ruleset_ident)
1002-
return call_with_current_trace(self._egraph.run_ruleset_with_scheduler, str(ruleset_ident), scheduler)
1003-
1004979
def stats(self) -> bindings.RunReport:
1005980
"""
1006981
Returns the overall run report for the egraph.
@@ -1340,7 +1315,11 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
13401315
"""
13411316
(output,) = self._run_program(bindings.PrintSize(span(1), None))
13421317
assert isinstance(output, bindings.PrintAllFunctionsSize)
1343-
return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
1318+
return [
1319+
(callables[0], size)
1320+
for (name, size) in output.sizes
1321+
if (callables := self._egg_fn_to_callables(name))
1322+
]
13441323

13451324
def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
13461325
return [
@@ -1603,14 +1582,11 @@ def __mul__(self, length: int) -> Schedule:
16031582
"""
16041583
return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
16051584

1606-
def saturate(self, *, stop_when_no_updates: bool = False) -> Schedule:
1585+
def saturate(self) -> Schedule:
16071586
"""
16081587
Run the schedule until the e-graph is saturated.
16091588
"""
1610-
return Schedule(
1611-
self.__egg_decls_thunk__,
1612-
SaturateDecl(self.schedule, stop_when_no_updates=stop_when_no_updates),
1613-
)
1589+
return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
16141590

16151591
def __add__(self, other: Schedule) -> Schedule:
16161592
"""
@@ -2112,7 +2088,7 @@ def back_off(
21122088
match_limit: None | int = None,
21132089
ban_length: None | int = None,
21142090
*,
2115-
egg_like: bool = False,
2091+
fresh_rematch: bool = False,
21162092
) -> BackOff:
21172093
"""
21182094
Create a backoff scheduler configuration.
@@ -2121,16 +2097,29 @@ def back_off(
21212097
schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
21222098
```
21232099
This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times,
2124-
using a backoff scheduler. Set `egg_like=True` to use the fresh-rematch variant
2100+
using a backoff scheduler. Set `fresh_rematch=True` to use the fresh-rematch variant
21252101
that is closer to `egg`/`hegg`; the default keeps egglog's backlog behavior.
21262102
"""
2127-
return BackOff(BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length, egg_like=egg_like))
2103+
return BackOff(
2104+
BackOffDecl(
2105+
id=uuid4(),
2106+
match_limit=match_limit,
2107+
ban_length=ban_length,
2108+
fresh_rematch=fresh_rematch,
2109+
)
2110+
)
21282111

21292112

21302113
@dataclass(frozen=True)
21312114
class BackOff:
21322115
scheduler: BackOffDecl
21332116

2117+
def persistent(self) -> BackOff:
2118+
"""
2119+
Reuse this scheduler across repeated runs on the same egraph.
2120+
"""
2121+
return BackOff(replace(self.scheduler, persistent=True))
2122+
21342123
def scope(self, schedule: Schedule) -> Schedule:
21352124
"""
21362125
Defines the scheduler to be created directly before the inner schedule, instead of the default which is at the

python/egglog/egraph_state.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
_TRACER = trace.get_tracer(__name__)
3131

3232

33+
def _normalize_global_let_name(name: str) -> str:
34+
return name if name.startswith("$") else f"${name}"
35+
36+
3337
def span(frame_index: int = 0) -> bindings.RustSpan:
3438
"""
3539
Returns a span for the current file and line.
@@ -44,10 +48,6 @@ def span(frame_index: int = 0) -> bindings.RustSpan:
4448
return bindings.RustSpan("", 0, 0)
4549

4650

47-
def _normalize_global_let_name(name: str) -> str:
48-
return name if name.startswith("$") else f"${name}"
49-
50-
5151
@dataclass
5252
class EGraphState:
5353
"""
@@ -107,6 +107,14 @@ def copy(self) -> EGraphState:
107107
def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]:
108108
return call_with_current_trace(self.egraph.run_program, *commands)
109109

110+
@staticmethod
111+
def _persistent_scheduler_name(scheduler: BackOffDecl) -> str:
112+
return f"_persistent_scheduler_{scheduler.id.hex}"
113+
114+
@staticmethod
115+
def _local_scheduler_name(index: int) -> str:
116+
return f"_scheduler_{index}"
117+
110118
@_TRACER.start_as_current_span("run_schedule_to_egg")
111119
def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
112120
"""
@@ -115,17 +123,19 @@ def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
115123
If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise
116124
will be a normal run command.
117125
"""
118-
processed_schedule = self._process_schedule(schedule)
126+
processed_schedule, persistent_schedulers = self._process_schedule(schedule)
119127
if processed_schedule is None:
120128
return bindings.RunSchedule(self._schedule_to_egg(schedule))
129+
for scheduler in persistent_schedulers:
130+
self._run_program(self._persistent_scheduler_to_egg(scheduler))
121131
top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
122132
if len(top_level_schedules) == 1:
123133
schedule_expr = top_level_schedules[0]
124134
else:
125135
schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
126136
return bindings.UserDefined(span(), "run-schedule", [schedule_expr])
127137

128-
def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
138+
def _process_schedule(self, schedule: ScheduleDecl) -> tuple[ScheduleDecl | None, tuple[BackOffDecl, ...]]:
129139
"""
130140
Processes a schedule to determine if it contains any custom schedulers.
131141
@@ -134,19 +144,23 @@ def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
134144
135145
Also processes all rulesets in the schedule to make sure they are registered.
136146
"""
137-
bound_schedulers: list[UUID] = []
147+
bound_schedulers: list[BackOffDecl] = []
138148
unbound_schedulers: list[BackOffDecl] = []
149+
persistent_schedulers: dict[UUID, BackOffDecl] = {}
139150

140151
def helper(s: ScheduleDecl) -> None:
141152
match s:
142153
case LetSchedulerDecl(scheduler, inner):
143-
bound_schedulers.append(scheduler.id)
154+
bound_schedulers.append(scheduler)
144155
return helper(inner)
145156
case RunDecl(ruleset_name, _, scheduler):
146157
self.ruleset_to_egg(ruleset_name)
147-
if scheduler and scheduler.id not in bound_schedulers:
148-
unbound_schedulers.append(scheduler)
149-
case SaturateDecl(inner, _) | RepeatDecl(inner, _):
158+
if scheduler and scheduler.id not in {s.id for s in bound_schedulers}:
159+
if scheduler.persistent:
160+
persistent_schedulers[scheduler.id] = scheduler
161+
else:
162+
unbound_schedulers.append(scheduler)
163+
case SaturateDecl(inner) | RepeatDecl(inner, _):
150164
return helper(inner)
151165
case SequenceDecl(schedules):
152166
for sc in schedules:
@@ -156,16 +170,16 @@ def helper(s: ScheduleDecl) -> None:
156170
return None
157171

158172
helper(schedule)
159-
if not bound_schedulers and not unbound_schedulers:
160-
return None
173+
if not bound_schedulers and not unbound_schedulers and not persistent_schedulers:
174+
return None, ()
161175
for scheduler in unbound_schedulers:
162176
schedule = LetSchedulerDecl(scheduler, schedule)
163-
return schedule
177+
return schedule, tuple(persistent_schedulers.values())
164178

165179
def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
166180
msg = "Should never reach this, let schedulers should be handled by custom scheduler"
167181
match schedule:
168-
case SaturateDecl(schedule, _):
182+
case SaturateDecl(schedule):
169183
return bindings.Saturate(span(), self._schedule_to_egg(schedule))
170184
case RepeatDecl(schedule, times):
171185
return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
@@ -184,33 +198,40 @@ def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
184198
assert_never(schedule)
185199

186200
def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
187-
self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
201+
self, schedule: ScheduleDecl, bound_schedulers: list[BackOffDecl]
188202
) -> list[bindings._Expr]:
189203
"""
190204
Turns a scheduler into an egg expression, to be used with a custom extract command.
191205
192206
The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
193207
"""
194208
match schedule:
195-
case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length, egg_like), inner):
196-
name = f"_scheduler_{len(bound_schedulers)}"
197-
bound_schedulers.append(id)
209+
case LetSchedulerDecl(scheduler, inner):
210+
match_limit = scheduler.match_limit
211+
ban_length = scheduler.ban_length
212+
fresh_rematch = scheduler.fresh_rematch
213+
name = self._local_scheduler_name(len(bound_schedulers))
214+
bound_schedulers.append(scheduler)
198215
args: list[bindings._Expr] = []
199216
if match_limit is not None:
200217
args.append(bindings.Var(span(), ":match-limit"))
201218
args.append(bindings.Lit(span(), bindings.Int(match_limit)))
202219
if ban_length is not None:
203220
args.append(bindings.Var(span(), ":ban-length"))
204221
args.append(bindings.Lit(span(), bindings.Int(ban_length)))
205-
scheduler_name = "back-off-egg" if egg_like else "back-off"
222+
scheduler_name = "back-off-fresh" if fresh_rematch else "back-off"
206223
back_off_decl = bindings.Call(span(), scheduler_name, args)
207224
let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
208225
return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
209226
case RunDecl(ruleset_ident, until, scheduler):
210227
args = [bindings.Var(span(), str(ruleset_ident))]
211228
if scheduler:
212229
name = "run-with"
213-
scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
230+
scheduler_name = self._persistent_scheduler_name(scheduler)
231+
for i, bound in enumerate(bound_schedulers):
232+
if bound.id == scheduler.id:
233+
scheduler_name = self._local_scheduler_name(i)
234+
break
214235
args.insert(0, bindings.Var(span(), scheduler_name))
215236
else:
216237
name = "run"
@@ -225,10 +246,8 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
225246
raise ValueError(msg)
226247
args.append(fact_egg.expr)
227248
return [bindings.Call(span(), name, args)]
228-
case SaturateDecl(inner, stop_when_no_updates):
249+
case SaturateDecl(inner):
229250
args = self._schedule_with_scheduler_to_egg(inner, bound_schedulers)
230-
if stop_when_no_updates:
231-
args = [bindings.Var(span(), ":stop-when-no-updates"), *args]
232251
return [bindings.Call(span(), "saturate", args)]
233252
case RepeatDecl(inner, times):
234253
return [
@@ -249,6 +268,22 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
249268
case _:
250269
assert_never(schedule)
251270

271+
def _persistent_scheduler_to_egg(self, scheduler: BackOffDecl) -> bindings._Command:
272+
args: list[bindings._Expr] = []
273+
if scheduler.match_limit is not None:
274+
args.append(bindings.Var(span(), ":match-limit"))
275+
args.append(bindings.Lit(span(), bindings.Int(scheduler.match_limit)))
276+
if scheduler.ban_length is not None:
277+
args.append(bindings.Var(span(), ":ban-length"))
278+
args.append(bindings.Lit(span(), bindings.Int(scheduler.ban_length)))
279+
scheduler_name = "back-off-fresh" if scheduler.fresh_rematch else "back-off"
280+
back_off_decl = bindings.Call(span(), scheduler_name, args)
281+
return bindings.UserDefined(
282+
span(),
283+
"let-scheduler",
284+
[bindings.Var(span(), self._persistent_scheduler_name(scheduler)), back_off_decl],
285+
)
286+
252287
def ruleset_to_egg(self, ident: Ident) -> None:
253288
"""
254289
Registers a ruleset if it's not already registered.

0 commit comments

Comments
 (0)