33
44from ....electromagnetics .time_domain .simulation import BaseTDEMSimulation as Sim
55from ....utils import Zero
6+ from multiprocessing import cpu_count
67import numpy as np
78import scipy .sparse as sp
89from time import time
@@ -105,7 +106,7 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
105106
106107Sim .dpred = dask_dpred
107108Sim .field_derivs = None
108- Sim . j_initialzer = None
109+
109110
110111
111112def compute_J (self , f = None , Ainv = None ):
@@ -117,6 +118,7 @@ def compute_J(self, f=None, Ainv=None):
117118 row_chunks = int (np .ceil (
118119 float (self .survey .nD ) / np .ceil (float (m_size ) * self .survey .nD * 8. * 1e-6 / self .max_chunk_size )
119120 ))
121+
120122 solution_type = self ._fieldType + "Solution" # the thing we solved for
121123
122124 if self .store_sensitivities == "disk" :
@@ -152,60 +154,67 @@ def compute_J(self, f=None, Ainv=None):
152154 self .field_derivs = dask .compute (field_derivs )[0 ]
153155
154156 f = dask .delayed (f )
155- field_derivatives = {}
157+ field_derivatives = None
158+ batch_map = {}
156159
157160 for tInd , dt in tqdm (zip (reversed (range (self .nT )), reversed (self .time_steps ))):
158161
159162 AdiagTinv = Ainv [dt ]
160163 Asubdiag = self .getAsubdiag (tInd )
161164 d_count = 0
165+ block_count = 0
162166 field_deriv_blocks = []
163167 j_row_blocks = []
164-
168+ count = 0
169+ batch_block = []
170+ batch_indices = []
171+ batch_count = 0
165172 for isrc , src in enumerate (self .survey .source_list ):
166173 field_blocks = []
167174 n_data = self .field_derivs [tInd + 1 ][isrc ][0 ].shape [1 ]
168175 n_blocks = int (np .ceil ((m_size * n_data ) * 8. * 1e-6 / 128. ))
169176 sub_blocks = np .array_split (np .arange (n_data ), n_blocks )
170177
171- for block_ind in sub_blocks :
172- if isrc not in field_derivatives :
173- ATinv_df_duT_v = (
174- AdiagTinv * self .field_derivs [tInd + 1 ][isrc ][0 ][:, block_ind ].toarray ()
175- )
176- else :
177- ATinv_df_duT_v = AdiagTinv * np .asarray (field_derivatives [isrc ][:, block_ind ])
178+ for i_block , block_ind in enumerate (sub_blocks ):
178179
179- if self .store_sensitivities == "disk" :
180- partial_derivs .set_orthogonal_selection (
181- (slice (None ), slice (d_count , d_count + len (block_ind ))),
182- ATinv_df_duT_v
183- )
180+ if field_derivatives is None :
181+ batch_block .append (self .field_derivs [tInd + 1 ][isrc ][0 ][:, block_ind ].toarray ())
182+ batch_map [isrc , i_block ] = (batch_count , count )
184183 else :
185- partial_derivs [:, d_count : d_count + len (block_ind )] = ATinv_df_duT_v
186-
187- field_blocks .append (
188- dask .array .from_delayed (
189- delayed (parallel_field_deriv , pure = True )(
190- partial_derivs [:, d_count : d_count + len (block_ind )], Asubdiag ,
191- self .field_derivs [tInd ][isrc ][0 ][:, block_ind ]
192- ),
193- shape = (Asubdiag .shape [0 ], len (block_ind )),
194- dtype = np .float64
195- )
196- )
197- j_row_blocks .append (dask .array .from_delayed (
198- delayed (parallel_block_compute , pure = True )(
199- self , f , src , partial_derivs [:, d_count : d_count + len (block_ind )],
200- tInd , solution_type , d_count , Jmatrix , self .field_derivs [tInd + 1 ][isrc ][1 ][block_ind , :]
201- ),
202- shape = (len (block_ind ), m_size ),
203- dtype = np .float32
204- ))
205- d_count += len (block_ind )
206-
207- field_deriv_blocks .append (dask .array .hstack (field_blocks ))
184+ i_file , i_block = batch_map [isrc , i_block ]
185+ batch_block .append (field_derivatives [i_file ][:, i_block :i_block + len (block_ind )])
186+
187+ batch_indices .append ((isrc , block_ind ))
188+ block_count += 1
208189
190+ if block_count >= cpu_count ():
191+ f_blocks , j_blocks = process_blocks (
192+ self , AdiagTinv , d_count , batch_block , batch_indices , Asubdiag , f , tInd ,
193+ solution_type , Jmatrix
194+ )
195+ field_deriv_blocks .append (dask .array .hstack (f_blocks ))
196+ j_row_blocks .append (j_blocks )
197+
198+ batch_block , batch_indices = [], []
199+ block_count = 0
200+ batch_count += 1
201+ d_count += count
202+ count = 0
203+
204+ count += len (block_ind )
205+ # if isrc not in field_derivatives:
206+ # ATinv_df_duT_v = (
207+ # AdiagTinv * self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray()
208+ # )
209+ # else:
210+ # ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivatives[isrc][:, block_ind])
211+
212+ f_blocks , j_blocks = process_blocks (
213+ self , AdiagTinv , d_count , batch_block , batch_indices , Asubdiag , f , tInd ,
214+ solution_type , Jmatrix
215+ )
216+ field_deriv_blocks .append (dask .array .hstack (f_blocks ))
217+ j_row_blocks .append (j_blocks )
209218 del field_derivatives
210219
211220 if self .store_sensitivities == "disk" :
@@ -224,8 +233,6 @@ def compute_J(self, f=None, Ainv=None):
224233 dask .compute (j_row_blocks )
225234 field_derivatives = dask .compute (field_deriv_blocks )[0 ]
226235
227- field_derivatives = {isrc : elem for isrc , elem in enumerate (field_derivatives )}
228-
229236 for A in Ainv .values ():
230237 A .clean ()
231238
@@ -238,6 +245,46 @@ def compute_J(self, f=None, Ainv=None):
238245Sim .compute_J = compute_J
239246
240247
def process_blocks(
    self, AdiagTinv, d_count, batch_block, batch_indices, Asubdiag, f, tInd,
    solution_type, Jmatrix
):
    """
    Build the delayed field-derivative and sensitivity-row tasks for one batch.

    Every block in ``batch_block`` is stacked column-wise and multiplied by
    ``AdiagTinv`` in a single product. The matching column slice of that
    product is then handed, block by block, to ``parallel_field_deriv`` and
    ``parallel_block_compute`` as dask delayed calls.

    Parameters
    ----------
    self :
        The simulation; ``self.field_derivs`` and ``self.survey.source_list``
        are read here.
    AdiagTinv :
        Left operand of ``*`` with the stacked dense blocks (``compute_J``
        passes ``Ainv[dt]``).
    d_count : int
        Offset forwarded to ``parallel_block_compute`` for the first block;
        it advances by each block's column count.
    batch_block : list
        Dense 2-D arrays, one per block, stacked with ``np.hstack``.
    batch_indices : list
        One ``(source index, column indices)`` pair per entry of
        ``batch_block``, in the same order.
    Asubdiag :
        Forwarded to ``parallel_field_deriv``; its ``shape[0]`` is the row
        count declared for each field-derivative block.
    f :
        Fields object forwarded to ``parallel_block_compute``.
    tInd : int
        Time index; selects ``self.field_derivs[tInd]`` and
        ``self.field_derivs[tInd + 1]``.
    solution_type : str
        Forwarded unchanged to ``parallel_block_compute``.
    Jmatrix :
        Forwarded to ``parallel_block_compute``; its ``shape[1]`` is the
        column count declared for each sensitivity-row block.

    Returns
    -------
    tuple of (list, list)
        ``(field_blocks, j_row_blocks)``: one dask array per input block in
        each list. Both lists are empty when ``batch_block`` is empty.
    """
    # compute_J flushes a batch once it holds cpu_count() blocks and then
    # calls this function again after the source loop, so that trailing call
    # can arrive with nothing pending. np.hstack raises ValueError on an
    # empty sequence, so bail out before touching it.
    # NOTE(review): compute_J feeds the returned list to dask.array.hstack;
    # confirm the caller skips an empty result.
    if not batch_block:
        return [], []

    # One product for the whole batch instead of one per block.
    ATinv_df_duT_v = AdiagTinv * np.hstack(batch_block)

    field_blocks = []
    j_row_blocks = []
    count = 0  # column offset of the current block inside ATinv_df_duT_v
    for block, indices in zip(batch_block, batch_indices):
        i_src, columns = indices[0], indices[1]
        block_size = block.shape[1]
        solved_block = ATinv_df_duT_v[:, count: count + block_size]

        field_blocks.append(
            dask.array.from_delayed(
                delayed(parallel_field_deriv, pure=True)(
                    solved_block,
                    Asubdiag,
                    self.field_derivs[tInd][i_src][0][:, columns]
                ),
                shape=(Asubdiag.shape[0], block_size),
                dtype=np.float64
            )
        )
        j_row_blocks.append(
            dask.array.from_delayed(
                delayed(parallel_block_compute, pure=True)(
                    self,
                    f,
                    self.survey.source_list[i_src],
                    solved_block,
                    tInd,
                    solution_type,
                    d_count,
                    Jmatrix,
                    self.field_derivs[tInd + 1][i_src][1][columns, :]
                ),
                shape=(block_size, Jmatrix.shape[1]),
                dtype=np.float32
            )
        )
        count += block_size
        d_count += block_size

    return field_blocks, j_row_blocks
286+
287+
241288def block_deriv (simulation , src , tInd , f , block_size ):
242289 src_field_derivs = []
243290 j_initial = []
0 commit comments