@@ -229,7 +229,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
         chunk_size = len(chunk)
 
         # Condition to start a new block
-        if (row_count + chunk_size) > (data_block_size * cpu_count() / 2):
+        if (row_count + chunk_size) > (data_block_size * cpu_count()):
             row_count = 0
             block_count += 1
             blocks[block_count] = {}
@@ -262,9 +262,17 @@ def deriv_block(
     return stacked_block
 
 
-def update_deriv_blocks(address, indices, derivatives, solve):
-    columns, local_ind = indices[address]
-    derivatives[address][:, local_ind] = solve[:, columns]
+def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape):
+    if address not in derivatives:
+        deriv_array = np.zeros(shape)
+    else:
+        deriv_array = derivatives[address].compute()
+
+    if address in indices:
+        columns, local_ind = indices[address]
+        deriv_array[:, local_ind] = solve[:, columns]
+
+    derivatives[address] = delayed(deriv_array)
 
 
 def get_field_deriv_block(
@@ -298,44 +306,41 @@ def get_field_deriv_block(
             local_ind,
         )
         count += len(sub_ind)
+        deriv_comp = deriv_block(
+            s_id,
+            r_id,
+            b_id,
+            ATinv_df_duT_v,
+            Asubdiag,
+            local_ind,
+            sub_ind,
+            simulation,
+            tInd,
+        )
 
         stacked_blocks.append(
-            deriv_block(
-                s_id,
-                r_id,
-                b_id,
-                ATinv_df_duT_v,
-                Asubdiag,
-                local_ind,
-                sub_ind,
-                simulation,
-                tInd,
+            array.from_delayed(
+                deriv_comp,
+                dtype=float,
+                shape=(
+                    simulation.field_derivs[tInd][s_id][r_id].shape[0],
+                    len(local_ind),
+                ),
             )
         )
-
     if len(stacked_blocks) > 0:
-        solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0])
+        blocks = array.hstack(stacked_blocks).compute()
+        solve = AdiagTinv * blocks
 
     update_list = []
-    for s_id, r_id, b_id in block:
-        if (s_id, r_id, b_id) not in ATinv_df_duT_v:
-            ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros(
-                (
-                    simulation.field_derivs[tInd][s_id][r_id].shape[0],
-                    len(block[(s_id, r_id, b_id)][0]),
-                )
-            )
-
-        if (s_id, r_id, b_id) in indices:
-            update_list.append(
-                update_deriv_blocks(
-                    (s_id, r_id, b_id),
-                    indices,
-                    ATinv_df_duT_v,
-                    solve,
-                )
-            )
-
+    for address in block:
+        shape = (
+            simulation.field_derivs[tInd][address[0]][address[1]].shape[0],
+            len(block[address][0]),
+        )
+        update_list.append(
+            update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape)
+        )
     dask.compute(update_list)
 
     return ATinv_df_duT_v
@@ -395,7 +400,7 @@ def compute_J(self, f=None, Ainv=None):
         f, Ainv = self.fields(self.model, return_Ainv=True)
 
     ftype = self._fieldType + "Solution"
-    Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32)
+    Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32))
    simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0
     data_times = self.survey.source_list[0].receiver_list[0].times
     blocks = get_parallel_blocks(
@@ -427,17 +432,15 @@ def compute_J(self, f=None, Ainv=None):
                     time_mask,
                 )
             )
-
         dask.compute(j_row_updates)
-
     for A in Ainv.values():
         A.clean()
 
     if self.store_sensitivities == "disk":
         del Jmatrix
         return array.from_zarr(self.sensitivity_path + f"J.zarr")
     else:
-        return Jmatrix
+        return Jmatrix.compute()
 
 
 Sim.compute_J = compute_J
0 commit comments