33import scvi
44from scipy .sparse import issparse , csr_matrix , csc_matrix
55import muon
6+ import scanpy as sc
67
78
89def get_representation (
9- adata : ad .AnnData , modality : Literal ["GEX" , "ADT" , "ATAC" ], use_hvg : bool = True , adt_normalization : Literal ["clr" , "log_cp10k" ] = "clr" ) -> ad .AnnData :
10+ adata : ad .AnnData ,
11+ modality : Literal ["GEX" , "ADT" , "ATAC" ],
12+ use_hvg : bool = True ,
13+ adt_normalization : Literal ["clr" , "log_cp10k" ] = "clr" ,
14+ plot_umap : bool = False ,
15+ ) -> ad .AnnData :
1016 """
1117 Get a joint latent space representation of the data based on the modality.
1218
@@ -29,6 +35,9 @@ def get_representation(
2935 Normalization method for ADT data. Options are:
3036 - "clr" (centered log-ratio transformation)
3137 - "log_cp10k" (normalization to 10k counts per cell and logarithm transformation)
38+ plot_umap
39+ Purely for diagnostic purposes, to see whether the data integration looks ok, this optionally computes
40+ a UMAP in shared latent space and stores a plot.
3241
3342 Returns
3443 -------
@@ -46,8 +55,9 @@ def get_representation(
4655 # Setup the AnnData object for scVI
4756 if modality == "GEX" :
4857 layer = "counts"
49- scvi .model .SCVI .setup_anndata (adata , batch_key = "batch" , layer = layer )
50- model = scvi .model .SCVI (adata , gene_likelihood = "nb" , n_layers = 2 , n_latent = 30 )
58+ scvi .model .SCVI .setup_anndata (adata , layer = layer , categorical_covariate_keys = ["split" , "batch" ])
59+ model = scvi .model .SCVI (adata )
60+
5161 elif modality == "ADT" :
5262 print (f"Normalizing the ADT data using method '{ adt_normalization } '" )
5363 if adt_normalization == "clr" :
@@ -60,11 +70,11 @@ def get_representation(
6070 raise ValueError (f"Unknown ADT normalization method: { adt_normalization } " )
6171
6272 layer = "adt_normalized"
63- scvi .model .SCVI .setup_anndata (adata , batch_key = "batch " , layer = layer )
73+ scvi .model .SCVI .setup_anndata (adata , layer = layer , categorical_covariate_keys = [ "split " , "batch" ] )
6474 model = scvi .model .SCVI (adata , gene_likelihood = "normal" , n_layers = 1 , n_latent = 10 )
6575 elif modality == "ATAC" :
6676 layer = "counts"
67- scvi .model .PEAKVI .setup_anndata (adata , batch_key = "batch " , layer = layer )
77+ scvi .model .PEAKVI .setup_anndata (adata , layer = layer , categorical_covariate_keys = [ "split " , "batch" ] )
6878 model = scvi .model .PEAKVI (adata )
6979 else :
7080 raise ValueError (f"Unknown modality: { modality } " )
@@ -80,4 +90,11 @@ def get_representation(
8090 # Get the latent representation
8191 adata .obsm ["X_scvi" ] = model .get_latent_representation ()
8292
93+ if plot_umap :
94+ sc .pp .neighbors (adata , use_rep = "X_scvi" )
95+ sc .tl .umap (adata )
96+
97+ plot_name = f"_{ modality } _{ adt_normalization } _use_hvg_{ use_hvg } .png"
98+ sc .pl .embedding (adata , basis = "umap" , color = ["batch" , "split" ], show = False , save = plot_name )
99+
83100 return adata
0 commit comments