@@ -66,34 +66,32 @@ def __init__(
6666
6767 if pool_fraction is None and replacement_prob is None :
6868 self ._cache_all = True
69- self ._pool_fraction = None
70- self ._pool_size = None
71- self ._replacement_prob = None
7269 else :
7370 if pool_fraction is None :
7471 raise ValueError ("pool_fraction must be provided if replacement_prob is provided." )
7572 if replacement_prob is None :
7673 raise ValueError ("replacement_prob must be provided if pool_fraction is provided." )
77- if not (0 < pool_fraction <= 1 ):
78- raise ValueError ("pool_fraction must be in (0, 1]." )
79- self ._pool_fraction = pool_fraction
80- self ._pool_size = math .ceil (pool_fraction * self .n_source_dists )
81- self ._replacement_prob = replacement_prob
82- if pool_fraction == 1.0 :
83- self ._cache_all = True
84-
85- self ._pool_usage_count = {}
74+ # Compute pool size from fraction
75+ if not (0 < pool_fraction <= 1 ):
76+ raise ValueError ("pool_fraction must be in (0, 1]." )
77+ self ._pool_fraction = pool_fraction
78+ self ._pool_size = math .ceil (pool_fraction * self .n_source_dists )
79+ self ._replacement_prob = replacement_prob
80+ self ._pool_usage_count = np .zeros (self .n_source_dists , dtype = int )
8681 self ._initialized = False
8782 self ._src_idx_pool = None
8883
84+ if pool_fraction == 1.0 :
85+ self ._cache_all = True
86+
8987 self ._lock = nullcontext () if self ._cache_all else threading .RLock ()
9088 self ._executor = None
9189 self ._pending_replacements = {}
9290 if not self ._cache_all :
9391 self ._executor = ThreadPoolExecutor (max_workers = 2 ) # TODO: avoid magic numbers
9492 self ._pending_replacements : dict [int , dict [str , Any ]] = {}
9593
96-
94+
9795
9896 def init_sampler (self , rng ) -> None :
9997 if self ._initialized :
@@ -104,11 +102,10 @@ def init_sampler(self, rng) -> None:
104102 return None
105103
def _init_src_idx_pool(self, rng) -> None:
    """Choose which source-distribution indices make up the pool.

    When everything is cached the pool is simply ``0 .. n_source_dists - 1``.
    Otherwise ``_pool_size`` distinct indices are drawn from that same range
    using ``rng``.

    Args:
        rng: Random generator exposing a NumPy-style ``choice`` method.
    """
    total_sources = self.n_source_dists

    if not self._cache_all:
        # Partial pool: sample without replacement so no index appears twice.
        self._src_idx_pool = rng.choice(
            total_sources,
            size=self._pool_size,
            replace=False,
        )
        return None

    # Full cache: every source distribution is a pool member.
    self._src_idx_pool = np.arange(total_sources)
    return None
113110
114111
@@ -126,32 +123,24 @@ def sample(self, rng) -> dict[str, Any]:
126123 """
127124 source_dist_idx = self ._sample_source_dist_idx (rng )
128125 target_dist_idx = self ._sample_target_dist_idx (rng , source_dist_idx )
126+ print (f"sampled source dist idx: { source_dist_idx } and target dist idx: { target_dist_idx } " )
129127 source_batch = self ._sample_source_cells (rng , source_dist_idx )
128+ print (f"sampled source batch: { source_batch .shape } " )
130129 target_batch = self ._sample_target_cells (rng , source_dist_idx , target_dist_idx )
131-
132- flat_condition = self ._data .data .conditions [target_dist_idx ]
133-
134- if hasattr (self ._data , 'annotation' ) and self ._data .annotation .condition_structure :
135- condition = {}
136- max_combination_length = getattr (self ._data , 'max_combination_length' , 1 )
137- for cov_name , (start , end ) in self ._data .annotation .condition_structure .items ():
138- condition [cov_name ] = flat_condition [start :end ].reshape (1 , max_combination_length , - 1 )
139- else :
140- condition = flat_condition
141-
130+ print (f"sampled target batch: { target_batch .shape } " )
142131 res = {
143132 "src_cell_data" : source_batch ,
144- "tgt_cell_data" : target_batch ,
145- "condition" : condition
133+ "tgt_cell_data" : target_batch
146134 }
135+ res ["condition" ] = self ._data .data .conditions [target_dist_idx ]
147136 return res
148137
149138
150139 def _load_targets_parallel (self , tgt_indices ):
151140 """Load multiple target distributions in parallel."""
152141 def _load_tgt (j : int ):
153142 return j , self ._data .data .tgt_data [j ][...]
154-
143+
155144 max_workers = min (32 , (os .cpu_count () or 4 )) # TODO: avoid magic numbers
156145 with ThreadPoolExecutor (max_workers = max_workers ) as ex :
157146 results = list (ex .map (_load_tgt , tgt_indices ))
@@ -160,12 +149,12 @@ def _load_tgt(j: int):
def _init_cache_pool_elements(self) -> None:
    """Populate the source and target caches for the current pool.

    Every source index in ``_src_idx_pool`` is read from ``src_data`` into
    ``_cached_srcs``. The target indices those sources map to (via
    ``src_to_tgt_dist_map``) are de-duplicated, sorted, and loaded through
    ``_load_targets_parallel`` into ``_cached_tgts``.
    """
    store = self._data.data
    pool = self._src_idx_pool

    with self._lock:
        loaded_sources = {}
        for src_idx in pool:
            # ``[...]`` presumably materialises a lazily stored array -- TODO confirm.
            loaded_sources[src_idx] = store.src_data[src_idx][...]
        self._cached_srcs = loaded_sources

    # Several sources may share a target; collect each target only once.
    unique_targets = set()
    for src_idx in pool:
        for tgt_idx in store.src_to_tgt_dist_map[src_idx]:
            unique_targets.add(int(tgt_idx))
    ordered_targets = sorted(unique_targets)

    with self._lock:
        self._cached_tgts = self._load_targets_parallel(ordered_targets)

    return None
170159
171160
@@ -189,7 +178,7 @@ def _sample_source_dist_idx(self, rng) -> int:
189178
def _sample_source_dist_idx_in_memory(self, rng) -> int:
    """Draw one cached source index and record that it was used.

    Args:
        rng: Random generator exposing a NumPy-style ``choice`` method.

    Returns:
        The chosen key of ``_cached_srcs``.
    """
    # Sort the keys so the draw does not depend on dict insertion order.
    candidates = sorted(self._cached_srcs.keys())
    picked = rng.choice(candidates)
    self._pool_usage_count[picked] += 1
    return picked
194183
195184 def _sample_source_dist_idx_in_pool (self , rng ) -> int :
@@ -199,70 +188,54 @@ def _sample_source_dist_idx_in_pool(self, rng) -> int:
199188 source_idx = rng .choice (sorted (self ._cached_srcs .keys ()))
200189
201190 # Increment usage count for monitoring
202- self ._pool_usage_count [source_idx ] = self . _pool_usage_count . get ( source_idx , 0 ) + 1
191+ self ._pool_usage_count [source_idx ] += 1
203192
204193 # Gradually replace elements based on replacement probability (schedule only)
205194 if rng .random () < self ._replacement_prob :
206195 self ._schedule_replacement (rng )
207-
196+
208197 return source_idx
209198
def _schedule_replacement(self, rng):
    """Schedule a background swap of one pool member for an uncached source.

    The evicted member is drawn uniformly from the pool members with the
    highest usage count. The replacement is drawn uniformly from the
    least-used sources that are neither in the pool nor already being
    loaded for another slot. Only the load is started here; the swap is
    recorded in ``_pending_replacements`` keyed by pool slot.

    Both draws are restricted to their own candidate set. Taking the
    max/min over *all* sources can pick an eviction target that is not in
    the pool (nothing gets scheduled) or a replacement that is already in
    the pool (duplicate pool entries).

    Args:
        rng: Random generator exposing a NumPy-style ``choice`` method.

    Returns:
        None. Returns early without scheduling when everything is cached,
        the pool is empty, the chosen slot already has a pending
        replacement, or no source outside the pool is available.
    """
    if self._cache_all:
        return  # No replacement if everything is cached

    with self._lock:
        pool = np.asarray(self._src_idx_pool)
        if pool.size == 0:
            return
        usage = self._pool_usage_count

        # Eviction candidates: the most-used sources *within the pool*.
        pool_usage = usage[pool]
        evict_weight = (pool_usage == pool_usage.max()).astype(float)
        evict_weight /= evict_weight.sum()
        in_pool_idx = int(rng.choice(pool.size, p=evict_weight))
        replaced_pool_idx = int(pool[in_pool_idx])

        # If there's already a pending replacement for this pool slot, skip
        if in_pool_idx in self._pending_replacements:
            return

        # Replacement candidates: sources not in the pool and not already
        # on their way in for a different slot.
        unavailable = np.zeros(self.n_source_dists, dtype=bool)
        unavailable[pool] = True
        incoming = [entry["new"] for entry in self._pending_replacements.values()]
        if incoming:
            unavailable[incoming] = True
        candidates = np.flatnonzero(~unavailable)
        if candidates.size == 0:
            return

        candidate_usage = usage[candidates]
        admit_weight = (candidate_usage == candidate_usage.min()).astype(float)
        admit_weight /= admit_weight.sum()
        new_pool_idx = int(candidates[rng.choice(candidates.size, p=admit_weight)])

        # Kick off background load for new indices
        fut = self._executor.submit(self._load_new_cache, new_pool_idx)
        self._pending_replacements[in_pool_idx] = {
            "old": replaced_pool_idx,
            "new": new_pool_idx,
            "future": fut,
        }
        print(f"scheduled replacement of {replaced_pool_idx} with {new_pool_idx} (slot {in_pool_idx})")
260233
261234 def _load_targets_parallel (self , tgt_indices ):
262235 """Load multiple target distributions in parallel."""
263236 def _load_tgt (j : int ):
264237 return j , self ._data .data .tgt_data [j ][...]
265-
238+
266239 max_workers = min (32 , (os .cpu_count () or 4 )) # TODO: avoid magic numbers
267240 with ThreadPoolExecutor (max_workers = max_workers ) as ex :
268241 results = list (ex .map (_load_tgt , tgt_indices ))
0 commit comments