@@ -85,12 +85,93 @@ def rechunk_sdata(sdata, CHUNK_SIZE=1024):
8585 sdata .labels [key ] = label_image
8686
8787
def subsample_adata_group_balanced(adata, group_key, n_samples, seed=0):
    """Subsample adata to a given number of samples, removing cells from large groups first

    Groups are ordered by size. The smallest groups are kept in full; every
    remaining (larger) group is capped at a common quota so that the total
    number of selected cells equals ``n_samples``. When the budget does not
    divide evenly, the smallest capped groups receive one extra cell each.

    Cells whose ``group_key`` value is missing (NaN) are treated as one
    additional group, so the returned mask never selects more than
    ``n_samples`` cells when subsampling is required.

    Arguments
    ---------
    adata: anndata.AnnData
        The adata to subsample. Only ``adata.obs`` is read.
    group_key: str
        The key in adata.obs to group by
    n_samples: int
        The number of samples to subsample to
    seed: int
        The seed to use for the random subsampling

    Returns
    -------
    pd.Series
        The series with the subsample information (boolean, True if the cell is in the subsample).
        Series index is the same as adata.obs_names.

    Raises
    ------
    ValueError
        If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    groups = adata.obs[group_key]
    n_obs = len(groups)

    # Nothing to remove: every cell is part of the subsample.
    if n_obs <= n_samples:
        return pd.Series(True, index=adata.obs.index, name="in_subsample", dtype=bool)

    # Local generator: reproducible for a given seed without reseeding the
    # process-wide numpy RNG that other code may rely on.
    rng = np.random.RandomState(seed)

    # Group sizes, smallest first (missing labels counted as their own group).
    n_cells = groups.value_counts(dropna=False).sort_values(ascending=True)
    counts = n_cells.to_numpy(dtype=np.int64)
    n_groups = len(counts)

    # reachable[i] = total obtained when groups before i are kept whole and
    # group i plus all larger groups are capped at the size of group i.
    # It is non-decreasing and its last entry equals n_obs (> n_samples), so
    # a first index reaching n_samples always exists.
    kept_before = np.cumsum(counts) - counts
    reachable = kept_before + (n_groups - np.arange(n_groups)) * counts
    first_capped = int(np.argmax(reachable >= n_samples))

    # Split the remaining budget evenly over the capped groups; the first
    # `extra` capped groups (the smallest ones) take one additional cell.
    budget = int(n_samples - kept_before[first_capped])
    base, extra = divmod(budget, n_groups - first_capped)
    quota = counts.copy()
    quota[first_capped:] = base
    quota[first_capped:first_capped + extra] += 1

    # Mark cells by position rather than by label so that duplicated
    # obs names cannot select more cells than requested.
    selected = np.zeros(n_obs, dtype=bool)
    for rank, label in enumerate(n_cells.index):
        member = groups.isna() if pd.isna(label) else groups == label
        positions = np.flatnonzero(member.to_numpy(dtype=bool, na_value=False))
        if positions.size == 0:
            # e.g. an unused category: nothing to select
            continue
        if rank < first_capped:
            selected[positions] = True
        else:
            chosen = rng.choice(positions, size=int(quota[rank]), replace=False)
            selected[chosen] = True

    return pd.Series(selected, index=adata.obs.index, name="in_subsample")
161+
162+
163+
# Read the single-cell dataset from the h5ad file given in par["input_sc"]
adata = ad.read_h5ad(par["input_sc"])

# Read the spatial dataset from the zarr store given in par["input_sp"]
sdata = sd.read_zarr(par["input_sp"])

# Cap the single-cell data at N_MAX_SC cells, subsampling balanced over "cell_type"
N_MAX_SC = 120_000
if N_MAX_SC < adata.n_obs:
    adata = adata[
        subsample_adata_group_balanced(adata, "cell_type", N_MAX_SC, seed=0)
    ]

# Collect the unique gene names present in each modality
sp_genes = sdata['transcripts']['feature_name'].unique().compute().tolist()
sc_genes = adata.var["feature_name"].unique().tolist()
0 commit comments