Skip to content

Commit ca3a1d6

Browse files
committed
Set maximum size of sc rna seq refs
1 parent b10ce05 commit ca3a1d6

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

src/data_processors/process_dataset/script.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,93 @@ def rechunk_sdata(sdata, CHUNK_SIZE=1024):
8585
sdata.labels[key] = label_image
8686

8787

88+
def subsample_adata_group_balanced(adata, group_key, n_samples, seed=0):
89+
"""Subsample adata to a given number of samples, removing cells from large groups first
90+
91+
Arguments
92+
---------
93+
adata: anndata.AnnData
94+
The adata to subsample
95+
group_key: str
96+
The key in adata.obs to group by
97+
n_samples: int
98+
The number of samples to subsample to
99+
seed: int
100+
The seed to use for the random subsampling
101+
102+
Returns
103+
-------
104+
pd.Series
105+
The series with the subsample information (boolean, True if the cell is in the subsample).
106+
Series index is the same as adata.obs_names.
107+
"""
108+
109+
np.random.seed(seed)
110+
111+
# Get the number of cells per group
112+
n_cells = adata.obs[group_key].value_counts().sort_values(ascending=True)
113+
114+
if n_cells.sum() <= n_samples:
115+
all_obs_df = adata.obs.copy()
116+
all_obs_df["in_subsample"] = True
117+
return all_obs_df["in_subsample"]
118+
119+
n_celltypes = len(n_cells)
120+
121+
# Find out which groups to subsample from
122+
df = pd.DataFrame({"n_cells": n_cells, "sum": 0, "n_samples":0}, dtype=int)
123+
subsample_from_idx = n_celltypes
124+
tmp = np.zeros(n_celltypes, dtype=int)
125+
for i in range(n_celltypes):
126+
tmp[i] = df.iloc[:i]["n_cells"].sum()
127+
tmp[i] += (n_celltypes - i) * df.iloc[i]["n_cells"]
128+
if tmp[i] >= n_samples:
129+
subsample_from_idx = i
130+
break
131+
df["sum"] = tmp
132+
133+
# Get number of samples per group
134+
n_samples_no_sampling = df.iloc[:subsample_from_idx]["n_cells"].sum()
135+
n_samples_to_subsample = n_samples - n_samples_no_sampling
136+
n_samples_per_group = n_samples_to_subsample // (n_celltypes - subsample_from_idx)
137+
n_samples_per_group_remainder = n_samples_to_subsample % (n_celltypes - subsample_from_idx)
138+
n_samples = np.zeros(n_celltypes, dtype=int)
139+
for i in range(subsample_from_idx):
140+
n_samples[i] = df.iloc[i]["n_cells"]
141+
for i in range(subsample_from_idx, n_celltypes):
142+
n_samples[i] = n_samples_per_group
143+
if n_samples_per_group_remainder > 0:
144+
n_samples[i] += 1
145+
n_samples_per_group_remainder -= 1
146+
df["n_samples"] = n_samples
147+
148+
# Subsample from the selected groups
149+
mask_df = adata.obs[[group_key]].copy()
150+
mask_df["in_subsample"] = False
151+
for i in range(subsample_from_idx):
152+
ct = df.index[i]
153+
mask_df.loc[mask_df[group_key] == ct, "in_subsample"] = True
154+
for i in range(subsample_from_idx, n_celltypes):
155+
ct = df.index[i]
156+
ct_obs = mask_df.loc[mask_df[group_key] == ct].index
157+
ct_obs_subsample = np.random.choice(ct_obs, size=df.iloc[i]["n_samples"], replace=False)
158+
mask_df.loc[ct_obs_subsample, "in_subsample"] = True
159+
160+
return mask_df["in_subsample"]
161+
162+
163+
88164
# Load the single-cell data
89165
adata = ad.read_h5ad(par["input_sc"])
90166

91167
# Load the spatial data
92168
sdata = sd.read_zarr(par["input_sp"])
93169

170+
# Subset single-cell data if it is too large
171+
N_MAX_SC = 120000
172+
if adata.n_obs > N_MAX_SC:
173+
adata = adata[subsample_adata_group_balanced(adata, "cell_type", N_MAX_SC, seed=0)]
174+
94175
# Subset single-cell and spatial data to shared genes
95176
sp_genes = sdata['transcripts']['feature_name'].unique().compute().tolist()
96177
sc_genes = adata.var["feature_name"].unique().tolist()

0 commit comments

Comments
 (0)