@@ -105,6 +105,8 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
105105
106106Sim .dpred = dask_dpred
107107Sim .field_derivs = None
108+ Sim .j_initialzer = None
109+
108110
109111def compute_J (self , f = None , Ainv = None ):
110112
@@ -115,50 +117,33 @@ def compute_J(self, f=None, Ainv=None):
115117 row_chunks = int (np .ceil (
116118 float (self .survey .nD ) / np .ceil (float (m_size ) * self .survey .nD * 8. * 1e-6 / self .max_chunk_size )
117119 ))
120+ solution_type = self ._fieldType + "Solution" # the thing we solved for
118121
119122 if self .store_sensitivities == "disk" :
120- self . J_initializer = zarr .open (
121- self .sensitivity_path + f"J_initializer .zarr" ,
123+ Jmatrix = zarr .open (
124+ self .sensitivity_path + f"J .zarr" ,
122125 mode = 'w' ,
123126 shape = (self .survey .nD , m_size ),
124127 chunks = (row_chunks , m_size )
125- )
128+ )# + J_initializer
126129 else :
127- self .J_initializer = np .zeros ((self .survey .nD , m_size ), dtype = np .float32 )
128- solution_type = self ._fieldType + "Solution" # the thing we solved for
130+ Jmatrix = np .zeros ((self .survey .nD , m_size ), dtype = np .float32 )
129131
130132 if self .field_derivs is None :
131-
132- # print("Start loop for field derivs")
133133 block_size = len (f [self .survey .source_list [0 ], solution_type , 0 ])
134-
135134 field_derivs = []
135+
136136 for tInd in range (self .nT + 1 ):
137137 d_count = 0
138138 df_duT_v = []
139139 for i_s , src in enumerate (self .survey .source_list ):
140- src_field_derivs = delayed (block_deriv , pure = True )(self , src , tInd , f , block_size , d_count )
140+ src_field_derivs = delayed (block_deriv , pure = True )(self , src , tInd , f , block_size )
141141 df_duT_v += [src_field_derivs ]
142142 d_count += np .sum ([rx .nD for rx in src .receiver_list ])
143143
144144 field_derivs += [df_duT_v ]
145- # print("Dask loop field derivs")
146- # tc = time()
147-
148145 self .field_derivs = dask .compute (field_derivs )[0 ]
149- # print(f"Done in {time() - tc} seconds")
150146
151- if self .store_sensitivities == "disk" :
152- Jmatrix = zarr .open (
153- self .sensitivity_path + f"J.zarr" ,
154- mode = 'w' ,
155- shape = (self .survey .nD , m_size ),
156- chunks = (row_chunks , m_size )
157- ) + self .J_initializer
158- else :
159- Jmatrix = dask .delayed (np .zeros ((self .survey .nD , m_size ), dtype = np .float32 ) + self .J_initializer )
160-
161- # ATinv_df_duT_v = {}
162147 f = dask .delayed (f )
163148 field_derivs_t = {}
164149
@@ -167,59 +152,76 @@ def compute_J(self, f=None, Ainv=None):
167152 AdiagTinv = Ainv [dt ]
168153 Asubdiag = self .getAsubdiag (tInd )
169154 d_count = 0
155+ field_deriv_blocks = []
170156 row_blocks = []
171157
172- # tc_loop = time()
173- # print(f"Loop sources for {tInd}")
174158 for isrc , src in enumerate (self .survey .source_list ):
175159 source_blocks = []
176- # for block in range(len(self.field_derivs[tInd][isrc])):
177- if isrc not in field_derivs_t :
178- ATinv_df_duT_v = dask .delayed (AdiagTinv * self .field_derivs [tInd + 1 ][isrc ].toarray ())
179- else :
180- ATinv_df_duT_v = dask .delayed (AdiagTinv * field_derivs_t [isrc ])
181-
182- n_data = self .field_derivs [tInd + 1 ][isrc ].shape [1 ]
160+ n_data = self .field_derivs [tInd + 1 ][isrc ][0 ].shape [1 ]
183161 n_blocks = int (np .ceil ((m_size * n_data ) * 8. * 1e-6 / 128. ))
184- ind_col = np .array_split (np .arange (n_data ), n_blocks )
162+ sub_blocks = np .array_split (np .arange (n_data ), n_blocks )
163+
164+ for block_ind in sub_blocks :
165+ if isrc not in field_derivs_t :
166+ ATinv_df_duT_v = (
167+ AdiagTinv * self .field_derivs [tInd + 1 ][isrc ][0 ][:, block_ind ].toarray ()
168+ )
169+ else :
170+ ATinv_df_duT_v = AdiagTinv * np .asarray (field_derivs_t [isrc ][:, block_ind ])
171+
172+ delayed_J_block = delayed (parallel_block_compute , pure = True )(
173+ self , f , src , ATinv_df_duT_v ,
174+ tInd , solution_type , d_count , Jmatrix , self .field_derivs [tInd + 1 ][isrc ][1 ][block_ind , :]
175+ )
176+
177+ delayed_field_block = delayed (parallel_field_deriv , pure = True )(
178+ ATinv_df_duT_v , Asubdiag , self .field_derivs [tInd ][isrc ][0 ][:, block_ind ]
179+ )
185180
186- for col_block in ind_col :
187181 source_blocks .append (
188182 dask .array .from_delayed (
189- delayed (parallel_block_compute , pure = True )(
190- self , f , src , ATinv_df_duT_v , d_count ,
191- col_block , tInd , solution_type , Jmatrix , Asubdiag ,
192- self .field_derivs [tInd ][isrc ]
193- ),
194- shape = self .field_derivs [tInd + 1 ][isrc ].shape ,
183+ delayed_field_block ,
184+ shape = (Asubdiag .shape [0 ], len (block_ind )),
195185 dtype = np .float32
196186 )
197187 )
188+
189+ row_blocks .append (dask .array .from_delayed (
190+ delayed_J_block ,
191+ shape = (len (block_ind ), m_size ),
192+ dtype = np .float32
193+ ))
198194 # print(f"Appending block {isrc} in {time() - tc} seconds")
199- d_count += len (col_block )
195+ d_count += len (block_ind )
196+
197+ field_deriv_blocks .append (dask .array .hstack (source_blocks ))
198+
199+ if self .store_sensitivities == "disk" :
200+ Jmatrix .set_orthogonal_selection (
201+ (np .arange (self .survey .nD ), slice (None )),
202+ Jmatrix + dask .array .vstack (row_blocks ).astype (np .float32 )
203+ )
204+ else :
205+ dask .compute (row_blocks )
200206
201- row_blocks .append (dask .array .hstack (source_blocks ))
202- # print(f"Done in {time() - tc_loop} seconds")
203- # tc = time()
204- # print(f"Compute field derivs for {tInd}")
205207 del field_derivs_t
206- field_derivs_t = {isrc : elem for isrc , elem in enumerate (dask .compute (row_blocks )[0 ])}
207- # print(f"Done in {time() - tc} seconds")
208+ field_derivs_t = {isrc : elem for isrc , elem in enumerate (dask .compute (field_deriv_blocks )[0 ])}
208209
209210 for A in Ainv .values ():
210211 A .clean ()
211212
212213 if self .store_sensitivities == "disk" :
213214 del Jmatrix
214215 return array .from_zarr (self .sensitivity_path + f"J.zarr" )
215- else :
216- return Jmatrix . compute ()
216+
217+ return Jmatrix
217218
218219Sim .compute_J = compute_J
219220
220221
221- def block_deriv (simulation , src , tInd , f , block_size , d_count ):
222- src_field_derivs = None
222+ def block_deriv (simulation , src , tInd , f , block_size ):
223+ src_field_derivs = []
224+ j_initial = []
223225 for rx in src .receiver_list :
224226
225227 v = sp .eye (rx .nD , dtype = float )
@@ -235,37 +237,33 @@ def block_deriv(simulation, src, tInd, f, block_size, d_count):
235237 PT_v [tInd * block_size :(tInd + 1 ) * block_size , :],
236238 adjoint = True ,
237239 )
238-
239- if not isinstance (cur [1 ], Zero ):
240- simulation .J_initializer [d_count :d_count + rx .nD , :] += cur [1 ].T
241-
242- if src_field_derivs is None :
243- src_field_derivs = cur [0 ]
244- else :
245- src_field_derivs += cur [0 ]
240+ src_field_derivs .append (cur [0 ])
241+ j_initial .append (cur [1 ].T )
246242
247243 # n_blocks = int(np.ceil(np.prod(src_field_derivs.shape) * 8. * 1e-6 / 128.))
248244 # ind_col = np.array_split(np.arange(src_field_derivs.shape[1]), col_blocks)
249245 # return [src_field_derivs[:, ind] for ind in ind_col]
250- return src_field_derivs
246+ return sp . hstack ( src_field_derivs ), sp . vstack ( j_initial )
251247
252- def parallel_block_compute (simulation , f , src , ATinv_df_duT_v , d_count , col_block , tInd , solution_type , Jmatrix , Asubdiag , field_derivs ):
253- field_derivs_t = np .asarray (
254- field_derivs [:, col_block ]
255- - Asubdiag .T * ATinv_df_duT_v [:, col_block ]
256- )
257248
249+ def parallel_field_deriv (ATinv_df_duT_v , Asubdiag , field_derivs ):
250+ return field_derivs - Asubdiag .T * ATinv_df_duT_v
251+
252+
253+ def parallel_block_compute (simulation , f , src , ATinv_df_duT_v , tInd , solution_type , d_count , Jmatrix , j_initial ):
258254 dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
259- tInd , f [src , solution_type , tInd ], ATinv_df_duT_v [:, col_block ] , adjoint = True
255+ tInd , f [src , solution_type , tInd ], ATinv_df_duT_v , adjoint = True
260256 )
261257
262258 dRHST_dm_v = simulation .getRHSDeriv (
263- tInd + 1 , src , ATinv_df_duT_v [:, col_block ] , adjoint = True
259+ tInd + 1 , src , ATinv_df_duT_v , adjoint = True
264260 )
265261 un_src = f [src , solution_type , tInd + 1 ]
266262 dAT_dm_v = simulation .getAdiagDeriv (
267- tInd , un_src , ATinv_df_duT_v [:, col_block ] , adjoint = True
263+ tInd , un_src , ATinv_df_duT_v , adjoint = True
268264 )
269- Jmatrix [d_count :d_count + dAT_dm_v .shape [1 ], :] += (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T
270265
271- return field_derivs_t
266+ if simulation .store_sensitivities == "disk" :
267+ return (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T + j_initial
268+
269+ Jmatrix [d_count :d_count + dAT_dm_v .shape [1 ], :] += (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T + j_initial
0 commit comments