Skip to content

Commit 99c620f

Browse files
committed
SymbolicAssignmentCollection.run_global_cse: do not modify in-place
1 parent 5a533a4 commit 99c620f

7 files changed

Lines changed: 36 additions & 29 deletions

File tree

benchmarks/bench_translations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def track_m2l_op_count(self, param):
8282
dvec, tgt_rscale)
8383
for i, expr in enumerate(result):
8484
sac.assign_unique(f"coeff{i}", expr)
85-
sac.run_global_cse()
85+
sac = sac.run_global_cse()
8686
insns = to_loopy_insns(sac.assignments.items())
8787
counter = pymbolic.mapper.flop_counter.CSEAwareFlopCounter()
8888

sumpy/assignment_collection.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
import logging
2727
from collections import defaultdict
28-
from typing import TYPE_CHECKING
28+
from typing import TYPE_CHECKING, overload
2929

30-
from typing_extensions import override
30+
from typing_extensions import Self, override
3131

3232
import sumpy.symbolic as sym
3333

@@ -196,10 +196,18 @@ def assign_temp(self, name_base: str, expr: sym.Basic) -> str:
196196
new_name = self.symbol_generator(name_base).name
197197
return self.add_assignment(new_name, expr, retain_name=False)
198198

199-
def run_global_cse(
200-
self,
201-
extra_exprs: Sequence[sym.Expr] | None = None
202-
) -> Sequence[sym.Basic]:
199+
@overload
200+
def run_global_cse(self, extra_exprs: None = None) -> Self: ...
201+
202+
@overload
203+
def run_global_cse(self,
204+
extra_exprs: Sequence[sym.Expr]
205+
) -> tuple[Self, Sequence[sym.Basic]]: ...
206+
207+
def run_global_cse(self,
208+
extra_exprs: Sequence[sym.Expr] | None = None
209+
) -> tuple[Self, Sequence[sym.Basic]] | Self:
210+
orig_extra_exprs = extra_exprs
203211
if extra_exprs is None:
204212
extra_exprs = []
205213

@@ -219,33 +227,32 @@ def run_global_cse(
219227
# from sumpy.symbolic import checked_cse
220228

221229
from sumpy.cse import cse
222-
new_assignments, new_exprs = cse(
230+
new_cse_assignments, new_exprs = cse(
223231
[*assign_exprs, *extra_exprs],
224232
symbols=self.symbol_generator)
225233

226234
new_assign_exprs = new_exprs[:len(assign_exprs)]
227235
new_extra_exprs = new_exprs[len(assign_exprs):]
228236

229-
for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
230-
self.assignments[name] = new_expr
237+
result_assignments: dict[str, sym.Basic] = {}
231238

232-
for name, value in new_assignments:
239+
for name, value in new_cse_assignments:
233240
assert isinstance(name, sym.Symbol)
234-
self.add_assignment(name.name, value)
241+
result_assignments[name.name] = value
235242

236-
for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
237-
# We want the assignment collection to be ordered correctly
238-
# to make it easier for loopy to schedule.
239-
# Deleting the original assignments and adding them again
240-
# makes them occur after the CSE'd expression preserving
241-
# the order of operations.
242-
del self.assignments[name]
243-
self.assignments[name] = new_expr
243+
result_assignments = {
244+
**result_assignments,
245+
**dict(zip(assign_names, new_assign_exprs, strict=True)),
246+
}
244247

245248
logger.info("common subexpression elimination: done after %.2f s",
246249
time.time() - start_time)
247250

248-
return new_extra_exprs
251+
result = type(self)(result_assignments)
252+
if orig_extra_exprs is None:
253+
return result
254+
else:
255+
return result, new_extra_exprs
249256

250257
# }}}
251258

sumpy/e2e.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_translation_loopy_insns(self):
121121
self.src_expansion, src_coeff_exprs, src_rscale,
122122
dvec=dvec, tgt_rscale=tgt_rscale, sac=sac))]
123123

124-
sac.run_global_cse()
124+
sac = sac.run_global_cse()
125125

126126
from sumpy.codegen import to_loopy_insns
127127
return to_loopy_insns(
@@ -177,7 +177,7 @@ def get_translation_loopy_insns(self):
177177
self.src_expansion, src_coeff_exprs, src_rscale,
178178
dvec, tgt_rscale, sac))]
179179

180-
sac.run_global_cse()
180+
sac = sac.run_global_cse()
181181

182182
from sumpy.codegen import to_loopy_insns
183183
return to_loopy_insns(
@@ -335,7 +335,7 @@ def get_translation_loopy_insns(self, result_dtype):
335335
m2l_translation_classes_dependent_data=(
336336
m2l_translation_classes_dependent_data)))]
337337

338-
sac.run_global_cse()
338+
sac = sac.run_global_cse()
339339

340340
from sumpy.codegen import to_loopy_insns
341341
return to_loopy_insns(

sumpy/expansion/loopy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def make_e2p_loopy_kernel(
9090
for i, knl in enumerate(kernels)
9191
}
9292

93-
sac.run_global_cse()
93+
sac = sac.run_global_cse()
9494

9595
code_transformers = (
9696
[expansion.get_code_transformer()]
@@ -194,7 +194,7 @@ def make_p2e_loopy_kernel(
194194
sac.add_assignment(f"coeffs{i}", coeff) for i, coeff in enumerate(coeffs)
195195
}
196196

197-
sac.run_global_cse()
197+
sac = sac.run_global_cse()
198198

199199
code_transformers = (
200200
[expansion.get_code_transformer()]

sumpy/expansion/m2l.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,7 @@ def loopy_translation_classes_dependent_data(
13241324
tgt_coeff_names = [
13251325
sac.assign_unique(f"m2l_translation_classes_dependent_data{i}", coeff_i)
13261326
for i, coeff_i in enumerate(derivatives)]
1327-
sac.run_global_cse()
1327+
sac = sac.run_global_cse()
13281328

13291329
from sumpy.codegen import to_loopy_insns
13301330
from sumpy.tools import to_complex_dtype

sumpy/p2p.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def get_loopy_insns_and_result_names(self):
147147
for i, expr in enumerate(exprs)
148148
]
149149

150-
sac.run_global_cse()
150+
sac = sac.run_global_cse()
151151

152152
from sumpy.codegen import to_loopy_insns
153153
loopy_insns = to_loopy_insns(sac.assignments.items(),

sumpy/qbx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def get_loopy_insns_and_result_names(self):
161161

162162
logger.info("compute expansion expressions: done")
163163

164-
sac.run_global_cse()
164+
sac = sac.run_global_cse()
165165

166166
pymbolic_expr_maps = [knl.get_code_transformer() for knl in [
167167
*self.target_kernels, *self.source_kernels]]

0 commit comments

Comments
 (0)