@@ -125,9 +125,16 @@ def compute_J(self, f=None, Ainv=None):
125125 mode = 'w' ,
126126 shape = (self .survey .nD , m_size ),
127127 chunks = (row_chunks , m_size )
128- )# + J_initializer
128+ )
129+ partial_derivs = zarr .open (
130+ self .sensitivity_path + f"partials.zarr" ,
131+ mode = 'w' ,
132+ shape = (self .getAsubdiag (0 ).shape [0 ], self .survey .nD ),
133+ chunks = (self .getAsubdiag (0 ).shape [0 ], row_chunks )
134+ )
129135 else :
130136 Jmatrix = np .zeros ((self .survey .nD , m_size ), dtype = np .float32 )
137+ partial_derivs = np .zeros ((self .getAsubdiag (0 ).shape [0 ], self .survey .nD ), dtype = np .float32 )
131138
132139 if self .field_derivs is None :
133140 block_size = len (f [self .survey .source_list [0 ], solution_type , 0 ])
@@ -145,67 +152,79 @@ def compute_J(self, f=None, Ainv=None):
145152 self .field_derivs = dask .compute (field_derivs )[0 ]
146153
147154 f = dask .delayed (f )
148- field_derivs_t = {}
155+ field_derivatives = {}
149156
150157 for tInd , dt in tqdm (zip (reversed (range (self .nT )), reversed (self .time_steps ))):
151158
152159 AdiagTinv = Ainv [dt ]
153160 Asubdiag = self .getAsubdiag (tInd )
154161 d_count = 0
155162 field_deriv_blocks = []
156- row_blocks = []
163+ j_row_blocks = []
157164
158165 for isrc , src in enumerate (self .survey .source_list ):
159- source_blocks = []
166+ field_blocks = []
160167 n_data = self .field_derivs [tInd + 1 ][isrc ][0 ].shape [1 ]
161168 n_blocks = int (np .ceil ((m_size * n_data ) * 8. * 1e-6 / 128. ))
162169 sub_blocks = np .array_split (np .arange (n_data ), n_blocks )
163170
164171 for block_ind in sub_blocks :
165- if isrc not in field_derivs_t :
172+ if isrc not in field_derivatives :
166173 ATinv_df_duT_v = (
167174 AdiagTinv * self .field_derivs [tInd + 1 ][isrc ][0 ][:, block_ind ].toarray ()
168175 )
169176 else :
170- ATinv_df_duT_v = AdiagTinv * np .asarray (field_derivs_t [isrc ][:, block_ind ])
171-
172- delayed_J_block = delayed (parallel_block_compute , pure = True )(
173- self , f , src , ATinv_df_duT_v ,
174- tInd , solution_type , d_count , Jmatrix , self .field_derivs [tInd + 1 ][isrc ][1 ][block_ind , :]
175- )
177+ ATinv_df_duT_v = AdiagTinv * np .asarray (field_derivatives [isrc ][:, block_ind ])
176178
177- delayed_field_block = delayed (parallel_field_deriv , pure = True )(
178- ATinv_df_duT_v , Asubdiag , self .field_derivs [tInd ][isrc ][0 ][:, block_ind ]
179- )
179+ if self .store_sensitivities == "disk" :
180+ partial_derivs .set_orthogonal_selection (
181+ (slice (None ), slice (d_count , d_count + len (block_ind ))),
182+ ATinv_df_duT_v
183+ )
184+ else :
185+ partial_derivs [:, d_count : d_count + len (block_ind )] = ATinv_df_duT_v
180186
181- source_blocks .append (
187+ field_blocks .append (
182188 dask .array .from_delayed (
183- delayed_field_block ,
189+ delayed (parallel_field_deriv , pure = True )(
190+ partial_derivs [:, d_count : d_count + len (block_ind )], Asubdiag ,
191+ self .field_derivs [tInd ][isrc ][0 ][:, block_ind ]
192+ ),
184193 shape = (Asubdiag .shape [0 ], len (block_ind )),
185- dtype = np .float32
194+ dtype = np .float64
186195 )
187196 )
188-
189- row_blocks .append (dask .array .from_delayed (
190- delayed_J_block ,
197+ j_row_blocks .append (dask .array .from_delayed (
198+ delayed (parallel_block_compute , pure = True )(
199+ self , f , src , partial_derivs [:, d_count : d_count + len (block_ind )],
200+ tInd , solution_type , d_count , Jmatrix , self .field_derivs [tInd + 1 ][isrc ][1 ][block_ind , :]
201+ ),
191202 shape = (len (block_ind ), m_size ),
192203 dtype = np .float32
193204 ))
194- # print(f"Appending block {isrc} in {time() - tc} seconds")
195205 d_count += len (block_ind )
196206
197- field_deriv_blocks .append (dask .array .hstack (source_blocks ))
207+ field_deriv_blocks .append (dask .array .hstack (field_blocks ))
208+
209+ del field_derivatives
198210
199211 if self .store_sensitivities == "disk" :
200212 Jmatrix .set_orthogonal_selection (
201213 (np .arange (self .survey .nD ), slice (None )),
202- Jmatrix + dask .array .vstack (row_blocks ).astype (np .float32 )
214+ Jmatrix + dask .array .vstack (j_row_blocks ).astype (np .float32 )
203215 )
216+ field_derivatives = [
217+ dask .array .to_zarr (
218+ field_deriv_blocks [i ], self .sensitivity_path + f"field_derivs_{ i } .zarr" ,
219+ overwrite = True ,
220+ return_stored = True ,
221+ ) for i in range (len (field_deriv_blocks ))
222+ ]
204223 else :
205- dask .compute (row_blocks )
224+ dask .compute (j_row_blocks )
225+ field_derivatives = dask .compute (field_deriv_blocks )[0 ]
206226
207- del field_derivs_t
208- field_derivs_t = {isrc : elem for isrc , elem in enumerate (dask .compute (field_deriv_blocks )[0 ])}
227+ field_derivatives = {isrc : elem for isrc , elem in enumerate (field_derivatives )}
209228
210229 for A in Ainv .values ():
211230 A .clean ()
0 commit comments