
Commit a510c1a

Revert "merging to main"

This reverts commit 58383b1.

1 parent f1c923f · commit a510c1a

7 files changed: 72 additions & 162 deletions

src/scaleflow/data/_data.py

Lines changed: 2 additions & 4 deletions
@@ -132,7 +132,6 @@ class GroupedDistributionAnnotation:
     tgt_dist_keys: list[str]
     src_dist_keys: list[str]
     dist_flag_key: str
-    condition_structure: dict[str, tuple[int, int]] | None = None  # Maps covariate name to (start, end) indices in flat array

     @classmethod
     def read_zarr(
@@ -294,9 +293,8 @@ def split_by_dist_df(self, dist_df: pd.DataFrame, column: str) -> dict[str, Grou
             .to_dict()
         )
         src_data = {int(k): self.data.src_data[k] for k in src_tgt_dist_map.keys()}
-        tgt_indices = {int(j) for tgt_list in src_tgt_dist_map.values() for j in tgt_list}
-        tgt_data = {int(k): self.data.tgt_data[k] for k in tgt_indices}
-        conditions = {int(k): self.data.conditions[k] for k in tgt_indices}
+        tgt_data = {int(k): self.data.tgt_data[k] for k in src_tgt_dist_map.keys()}
+        conditions = {int(k): self.data.conditions[k] for k in src_tgt_dist_map.keys()}
         split_data[value] = GroupedDistributionData(
             src_to_tgt_dist_map=src_tgt_dist_map,
             src_data=src_data,
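
For reference, a minimal sketch (toy dicts, not the scaleflow API) contrasting the two keying strategies in split_by_dist_df: the reverted code keyed tgt_data and conditions by the union of mapped target indices, while the restored code keys them by the source indices directly.

src_to_tgt_dist_map = {0: [0, 1], 1: [2]}  # source idx -> list of target idxs
tgt_data = {0: "t0", 1: "t1", 2: "t2"}

# Pre-revert: collect every target index reachable from the kept sources.
tgt_indices = {int(j) for tgt_list in src_to_tgt_dist_map.values() for j in tgt_list}
by_target = {int(k): tgt_data[k] for k in tgt_indices}  # keys 0, 1, 2

# Post-revert: reuse the source keys as-is.
by_source = {int(k): tgt_data[k] for k in src_to_tgt_dist_map.keys()}  # keys 0, 1

print(by_target, by_source)
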

src/scaleflow/data/_dataloader.py

Lines changed: 38 additions & 65 deletions
@@ -66,34 +66,32 @@ def __init__(

         if pool_fraction is None and replacement_prob is None:
             self._cache_all = True
-            self._pool_fraction = None
-            self._pool_size = None
-            self._replacement_prob = None
         else:
             if pool_fraction is None:
                 raise ValueError("pool_fraction must be provided if replacement_prob is provided.")
             if replacement_prob is None:
                 raise ValueError("replacement_prob must be provided if pool_fraction is provided.")
-            if not (0 < pool_fraction <= 1):
-                raise ValueError("pool_fraction must be in (0, 1].")
-            self._pool_fraction = pool_fraction
-            self._pool_size = math.ceil(pool_fraction * self.n_source_dists)
-            self._replacement_prob = replacement_prob
-            if pool_fraction == 1.0:
-                self._cache_all = True
-
-        self._pool_usage_count = {}
+            # Compute pool size from fraction
+            if not (0 < pool_fraction <= 1):
+                raise ValueError("pool_fraction must be in (0, 1].")
+            self._pool_fraction = pool_fraction
+            self._pool_size = math.ceil(pool_fraction * self.n_source_dists)
+            self._replacement_prob = replacement_prob
+        self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int)
         self._initialized = False
         self._src_idx_pool = None

+        if pool_fraction == 1.0:
+            self._cache_all = True
+
         self._lock = nullcontext() if self._cache_all else threading.RLock()
         self._executor = None
         self._pending_replacements = {}
         if not self._cache_all:
             self._executor = ThreadPoolExecutor(max_workers=2)  # TODO: avoid magic numbers
             self._pending_replacements: dict[int, dict[str, Any]] = {}

-
+

     def init_sampler(self, rng) -> None:
         if self._initialized:
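
A minimal standalone sketch of the pool-sizing logic restored in __init__ (n_source_dists stands in for the loader's attribute; pool_size as a free function is hypothetical):

import math

def pool_size(pool_fraction: float | None, replacement_prob: float | None,
              n_source_dists: int) -> int | None:
    """Return the cache-pool size, or None when everything is cached."""
    if pool_fraction is None and replacement_prob is None:
        return None  # cache_all: no pool needed
    if pool_fraction is None or replacement_prob is None:
        raise ValueError("pool_fraction and replacement_prob must be provided together.")
    if not (0 < pool_fraction <= 1):
        raise ValueError("pool_fraction must be in (0, 1].")
    return math.ceil(pool_fraction * n_source_dists)

print(pool_size(0.25, 0.1, 10))  # ceil(2.5) -> 3
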
@@ -104,11 +102,10 @@ def init_sampler(self, rng) -> None:
         return None

     def _init_src_idx_pool(self, rng) -> None:
-        src_indices = np.array(list(self._data.data.src_data.keys()))
         if self._cache_all:
-            self._src_idx_pool = src_indices
+            self._src_idx_pool = np.arange(self.n_source_dists)
         else:
-            self._src_idx_pool = rng.choice(src_indices, size=self._pool_size, replace=False)
+            self._src_idx_pool = rng.choice(self.n_source_dists, size=self._pool_size, replace=False)
         return None

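The sampling change above relies on NumPy's Generator.choice accepting an int, in which case it draws from np.arange(n); a quick sketch (toy sizes):

import numpy as np

rng = np.random.default_rng(0)
n_source_dists, pool_size = 10, 4
# Dense 0..n-1 source indices no longer need to be materialized from data keys.
pool = rng.choice(n_source_dists, size=pool_size, replace=False)
print(pool)  # a random size-4 subset of 0..9
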
@@ -126,32 +123,24 @@ def sample(self, rng) -> dict[str, Any]:
         """
         source_dist_idx = self._sample_source_dist_idx(rng)
         target_dist_idx = self._sample_target_dist_idx(rng, source_dist_idx)
+        print(f"sampled source dist idx: {source_dist_idx} and target dist idx: {target_dist_idx}")
         source_batch = self._sample_source_cells(rng, source_dist_idx)
+        print(f"sampled source batch: {source_batch.shape}")
         target_batch = self._sample_target_cells(rng, source_dist_idx, target_dist_idx)
-
-        flat_condition = self._data.data.conditions[target_dist_idx]
-
-        if hasattr(self._data, 'annotation') and self._data.annotation.condition_structure:
-            condition = {}
-            max_combination_length = getattr(self._data, 'max_combination_length', 1)
-            for cov_name, (start, end) in self._data.annotation.condition_structure.items():
-                condition[cov_name] = flat_condition[start:end].reshape(1, max_combination_length, -1)
-        else:
-            condition = flat_condition
-
+        print(f"sampled target batch: {target_batch.shape}")
         res = {
             "src_cell_data": source_batch,
-            "tgt_cell_data": target_batch,
-            "condition": condition
+            "tgt_cell_data": target_batch
         }
+        res["condition"] = self._data.data.conditions[target_dist_idx]
         return res


     def _load_targets_parallel(self, tgt_indices):
         """Load multiple target distributions in parallel."""
         def _load_tgt(j: int):
             return j, self._data.data.tgt_data[j][...]
-
+
         max_workers = min(32, (os.cpu_count() or 4))  # TODO: avoid magic numbers
         with ThreadPoolExecutor(max_workers=max_workers) as ex:
             results = list(ex.map(_load_tgt, tgt_indices))
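
A minimal sketch of the condition_structure slicing that this hunk deletes: a flat condition vector is cut into named covariate blocks via (start, end) index pairs. Values and covariate names here are toy placeholders; shapes follow the deleted reshape(1, max_combination_length, -1).

import numpy as np

condition_structure = {"drug": (0, 4), "dosage": (4, 5)}  # hypothetical covariates
max_combination_length = 1
flat_condition = np.arange(5.0)

condition = {
    cov: flat_condition[start:end].reshape(1, max_combination_length, -1)
    for cov, (start, end) in condition_structure.items()
}
print({k: v.shape for k, v in condition.items()})  # {'drug': (1, 1, 4), 'dosage': (1, 1, 1)}
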
@@ -160,12 +149,12 @@ def _load_tgt(j: int):
     def _init_cache_pool_elements(self) -> None:
         with self._lock:
             self._cached_srcs = {i: self._data.data.src_data[i][...] for i in self._src_idx_pool}
-
+
         tgt_indices = sorted({int(j) for i in self._src_idx_pool for j in self._data.data.src_to_tgt_dist_map[i]})
-
+
         with self._lock:
             self._cached_tgts = self._load_targets_parallel(tgt_indices)
-
+
         return None

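For context, the _load_targets_parallel pattern used above maps a loader over target indices on a thread pool and rebuilds the index-to-array dict; a runnable sketch with an in-memory stand-in for the zarr-backed tgt_data:

from concurrent.futures import ThreadPoolExecutor
import os

tgt_store = {0: [0.0], 1: [1.0], 2: [2.0]}  # hypothetical stand-in for tgt_data

def _load_tgt(j: int):
    return j, tgt_store[j]  # the real code slices with [...] to pull into memory

max_workers = min(32, (os.cpu_count() or 4))
with ThreadPoolExecutor(max_workers=max_workers) as ex:
    results = dict(ex.map(_load_tgt, [0, 1, 2]))
print(results)
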
@@ -189,7 +178,7 @@ def _sample_source_dist_idx(self, rng) -> int:

     def _sample_source_dist_idx_in_memory(self, rng) -> int:
         source_idx = rng.choice(sorted(self._cached_srcs.keys()))
-        self._pool_usage_count[source_idx] = self._pool_usage_count.get(source_idx, 0) + 1
+        self._pool_usage_count[source_idx] += 1
         return source_idx

     def _sample_source_dist_idx_in_pool(self, rng) -> int:
@@ -199,70 +188,54 @@ def _sample_source_dist_idx_in_pool(self, rng) -> int:
         source_idx = rng.choice(sorted(self._cached_srcs.keys()))

         # Increment usage count for monitoring
-        self._pool_usage_count[source_idx] = self._pool_usage_count.get(source_idx, 0) + 1
+        self._pool_usage_count[source_idx] += 1

         # Gradually replace elements based on replacement probability (schedule only)
         if rng.random() < self._replacement_prob:
             self._schedule_replacement(rng)
-
+
         return source_idx

     def _schedule_replacement(self, rng):
         if self._cache_all:
             return  # No replacement if everything is cached
-
-        # Get usage counts for indices in the pool
-        pool_indices = self._src_idx_pool.tolist()
-        usage_counts = np.array([self._pool_usage_count.get(idx, 0) for idx in pool_indices])
-
-        if len(usage_counts) == 0:
-            return
-
-        max_usage = usage_counts.max()
-        most_used_weight = (usage_counts == max_usage).astype(float)
+        # weights same as previous logic
+        most_used_weight = (self._pool_usage_count == self._pool_usage_count.max()).astype(float)
         if most_used_weight.sum() == 0:
             return
         most_used_weight /= most_used_weight.sum()
-        replaced_pool_slot = rng.choice(len(pool_indices), p=most_used_weight)
-        replaced_pool_idx = pool_indices[replaced_pool_slot]
+        replaced_pool_idx = rng.choice(self.n_source_dists, p=most_used_weight)

         with self._lock:
-            # If there's already a pending replacement for this pool slot, skip
-            if replaced_pool_slot in self._pending_replacements:
+            pool_set = set(self._src_idx_pool.tolist())
+            if replaced_pool_idx not in pool_set:
                 return
+            in_pool_idx = int(np.where(self._src_idx_pool == replaced_pool_idx)[0][0])

-            # Find all available source indices (not currently in pool)
-            all_src_indices = list(self._data.data.src_data.keys())
-            pool_set = set(pool_indices)
-            available_indices = [idx for idx in all_src_indices if idx not in pool_set]
-
-            if not available_indices:
+            # If there's already a pending replacement for this pool slot, skip
+            if in_pool_idx in self._pending_replacements:
                 return

-            # Get usage counts for available indices
-            available_usage = np.array([self._pool_usage_count.get(idx, 0) for idx in available_indices])
-            min_usage = available_usage.min()
-            least_used_weight = (available_usage == min_usage).astype(float)
+            least_used_weight = (self._pool_usage_count == self._pool_usage_count.min()).astype(float)
             if least_used_weight.sum() == 0:
                 return
             least_used_weight /= least_used_weight.sum()
-            new_idx_position = rng.choice(len(available_indices), p=least_used_weight)
-            new_pool_idx = available_indices[new_idx_position]
+            new_pool_idx = int(rng.choice(self.n_source_dists, p=least_used_weight))

             # Kick off background load for new indices
             fut: Future = self._executor.submit(self._load_new_cache, new_pool_idx)
-            self._pending_replacements[replaced_pool_slot] = {
+            self._pending_replacements[in_pool_idx] = {
                 "old": replaced_pool_idx,
                 "new": new_pool_idx,
                 "future": fut,
             }
-            print(f"scheduled replacement of {replaced_pool_idx} with {new_pool_idx} (slot {replaced_pool_slot})")
+            print(f"scheduled replacement of {replaced_pool_idx} with {new_pool_idx} (slot {in_pool_idx})")

     def _load_targets_parallel(self, tgt_indices):
         """Load multiple target distributions in parallel."""
         def _load_tgt(j: int):
             return j, self._data.data.tgt_data[j][...]
-
+
         max_workers = min(32, (os.cpu_count() or 4))  # TODO: avoid magic numbers
         with ThreadPoolExecutor(max_workers=max_workers) as ex:
             results = list(ex.map(_load_tgt, tgt_indices))
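
A minimal sketch of the restored replacement weighting in _schedule_replacement: usage counts live in a dense array indexed by source distribution, the most-used index becomes the eviction candidate and the least-used the load candidate (the restored code additionally bails out if the eviction candidate is not currently in the pool). Counts here are toy values.

import numpy as np

rng = np.random.default_rng(0)
pool_usage_count = np.array([5, 0, 3, 5, 1])

most_used = (pool_usage_count == pool_usage_count.max()).astype(float)
most_used /= most_used.sum()
evict = rng.choice(len(pool_usage_count), p=most_used)  # index 0 or 3

least_used = (pool_usage_count == pool_usage_count.min()).astype(float)
least_used /= least_used.sum()
load = rng.choice(len(pool_usage_count), p=least_used)  # index 1
print(evict, load)
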

src/scaleflow/data/_datamanager.py

Lines changed: 1 addition & 25 deletions
@@ -114,31 +114,8 @@ def prepare_data(
         tgt_dist_labels = dict(zip(tgt_dist_labels.index, tgt_dist_labels.itertuples(index=False, name=None), strict=True))


-        # prepare conditions and structure metadata
+        # prepare conditions
         col_to_repr = {key: adata.uns[self.rep_keys[key]] for key in self.rep_keys.keys()}
-
-        # Compute condition_structure from first available label
-        condition_structure = {}
-        offset = 0
-        first_src_label = next(iter(src_dist_labels.values()))
-        first_tgt_label = next(iter(tgt_dist_labels.values()))
-
-        for col, label in zip(self.src_dist_keys, first_src_label, strict=True):
-            if col in col_to_repr:
-                dim = len(col_to_repr[col][label])
-                condition_structure[col] = (offset, offset + dim)
-                offset += dim
-
-        for col, label in zip(self.tgt_dist_keys, first_tgt_label, strict=True):
-            if col in col_to_repr:
-                dim = len(col_to_repr[col][label])
-                condition_structure[col] = (offset, offset + dim)
-                offset += dim
-            elif isinstance(label, (int, float)):
-                # Scalar value (like dosage)
-                condition_structure[col] = (offset, offset + 1)
-                offset += 1
-
         with timer("Getting conditions", verbose=verbose):
             conditions = {}
             for src_dist_idx, tgt_dist_idxs in src_to_tgt_dist_map.items():
@@ -183,7 +160,6 @@ def prepare_data(
                 src_dist_idx_to_labels=src_dist_labels,
                 tgt_dist_idx_to_labels=tgt_dist_labels,
                 default_values=default_values,
-                condition_structure=condition_structure,
             ),
         )
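
The offset bookkeeping deleted above is the producer side of the condition_structure consumed in the dataloader: each covariate claims a (start, end) slice of the flat condition vector. A minimal sketch with hypothetical covariate dimensions:

col_to_dim = {"drug": 4, "cell_line": 8, "dosage": 1}  # hypothetical dims

condition_structure, offset = {}, 0
for col, dim in col_to_dim.items():
    condition_structure[col] = (offset, offset + dim)
    offset += dim
print(condition_structure)  # {'drug': (0, 4), 'cell_line': (4, 12), 'dosage': (12, 13)}
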
