2525
2626import logging
2727from collections import defaultdict
28- from typing import TYPE_CHECKING
28+ from typing import TYPE_CHECKING , overload
2929
30- from typing_extensions import override
30+ from typing_extensions import Self , override
3131
3232import sumpy .symbolic as sym
3333
@@ -196,10 +196,18 @@ def assign_temp(self, name_base: str, expr: sym.Basic) -> str:
196196 new_name = self .symbol_generator (name_base ).name
197197 return self .add_assignment (new_name , expr , retain_name = False )
198198
199- def run_global_cse (
200- self ,
201- extra_exprs : Sequence [sym .Expr ] | None = None
202- ) -> Sequence [sym .Basic ]:
199+ @overload
200+ def run_global_cse (self , extra_exprs : None = None ) -> Self : ...
201+
202+ @overload
203+ def run_global_cse (self ,
204+ extra_exprs : Sequence [sym .Expr ]
205+ ) -> tuple [Self , Sequence [sym .Basic ]]: ...
206+
207+ def run_global_cse (self ,
208+ extra_exprs : Sequence [sym .Expr ] | None = None
209+ ) -> tuple [Self , Sequence [sym .Basic ]] | Self :
210+ orig_extra_exprs = extra_exprs
203211 if extra_exprs is None :
204212 extra_exprs = []
205213
@@ -219,33 +227,32 @@ def run_global_cse(
219227 # from sumpy.symbolic import checked_cse
220228
221229 from sumpy .cse import cse
222- new_assignments , new_exprs = cse (
230+ new_cse_assignments , new_exprs = cse (
223231 [* assign_exprs , * extra_exprs ],
224232 symbols = self .symbol_generator )
225233
226234 new_assign_exprs = new_exprs [:len (assign_exprs )]
227235 new_extra_exprs = new_exprs [len (assign_exprs ):]
228236
229- for name , new_expr in zip (assign_names , new_assign_exprs , strict = True ):
230- self .assignments [name ] = new_expr
237+ result_assignments : dict [str , sym .Basic ] = {}
231238
232- for name , value in new_assignments :
239+ for name , value in new_cse_assignments :
233240 assert isinstance (name , sym .Symbol )
234- self . add_assignment ( name .name , value )
241+ result_assignments [ name .name ] = value
235242
236- for name , new_expr in zip (assign_names , new_assign_exprs , strict = True ):
237- # We want the assignment collection to be ordered correctly
238- # to make it easier for loopy to schedule.
239- # Deleting the original assignments and adding them again
240- # makes them occur after the CSE'd expression preserving
241- # the order of operations.
242- del self .assignments [name ]
243- self .assignments [name ] = new_expr
243+ result_assignments = {
244+ ** result_assignments ,
245+ ** dict (zip (assign_names , new_assign_exprs , strict = True )),
246+ }
244247
245248 logger .info ("common subexpression elimination: done after %.2f s" ,
246249 time .time () - start_time )
247250
248- return new_extra_exprs
251+ result = type (self )(result_assignments )
252+ if orig_extra_exprs is None :
253+ return result
254+ else :
255+ return result , new_extra_exprs
249256
250257# }}}
251258
0 commit comments