Skip to content

Commit 3a0034e

Browse files
author
Habib Rehman
committed
adding resolvi
1 parent 778f2b8 commit 3a0034e

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
__merge__: /src/api/comp_method_expression_correction.yaml
2+
3+
name: resolvi_correction
4+
label: "resolVI Correction"
5+
summary: "Corrects the expression of genes using resolVI"
6+
description: >-
7+
Corrects the expression of genes based on the resolVI method, a part of scvi-tools.
8+
links:
9+
documentation: "https://docs.scvi-tools.org/en/latest/user_guide/models/resolvi.html"
10+
repository: "https://github.com/scverse/scvi-tools"
11+
references:
12+
doi: "10.1101/2025.01.20.634005"
13+
14+
arguments:
15+
- name: --celltype_key
16+
required: false
17+
direction: input
18+
type: string
19+
default: cell_type
20+
21+
resources:
22+
- type: python_script
23+
path: script.py
24+
25+
engines:
26+
- type: docker
27+
image: openproblems/base_python:1.0.0
28+
__merge__:
29+
- /src/base/setup_txsim_partial.yaml
30+
setup:
31+
- type: python
32+
pypi: [scvi-tools]
33+
- type: native
34+
35+
runners:
36+
- type: executable
37+
- type: nextflow
38+
directives:
39+
label: [ midtime, highcpu, highmem ]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import anndata as ad
2+
import txsim as tx
3+
import scvi
4+
import pandas as pd
5+
import scanpy as sc
6+
import scipy
7+
import numpy as np
8+
9+
## VIASH START
10+
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
11+
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
12+
par = {
13+
'input_spatial_with_cell_types': 'resources_test/task_ist_preprocessing/mouse_brain_combined/spatial_with_cell_types.h5ad',
14+
'celltype_key': 'cell_type',
15+
'output': '../resolvi_spatial_corrected.h5ad',
16+
}
17+
meta = {
18+
'name': 'gene_efficiency_correction',
19+
}
20+
## VIASH END
21+
22+
# Optional parameter check: For this specific correction method the par['input_sc'] is required
23+
24+
# Read input
25+
print('Reading input files', flush=True)
26+
adata_sp = ad.read_h5ad(par['input_spatial_with_cell_types'])
27+
adata_sp.layers["normalized_uncorrected"] = adata_sp.layers["normalized"]
28+
29+
30+
#TODO add resolvi here
31+
32+
print("Filter cells with <5 counts")
33+
sc.pp.filter_cells(adata_sp, min_genes=5)
34+
35+
spatial_array = np.stack([adata_sp.obs['centroid_x'].values, adata_sp.obs['centroid_y'].values], axis=1)
36+
adata_sp.obsm['X_spatial'] = spatial_array
37+
38+
# Apply gene efficiency correction
39+
print('Running ResolVI', flush=True)
40+
41+
scvi.external.RESOLVI.setup_anndata(adata_sp, labels_key=par['celltype_key'], layer="counts")
42+
43+
supervised_resolvi = scvi.external.RESOLVI(adata_sp, semisupervised=True)
44+
45+
supervised_resolvi.train(max_epochs=50)
46+
47+
samples_corr = supervised_resolvi.sample_posterior(
48+
model=supervised_resolvi.module.model_corrected,
49+
return_sites=['px_rate', 'obs'],
50+
num_samples=20, return_samples=False, batch_size=4000) #batch_steps was not a parameter
51+
samples_corr = pd.DataFrame(samples_corr).T
52+
53+
samples = supervised_resolvi.sample_posterior(
54+
model=supervised_resolvi.module.model_residuals,
55+
return_sites=[
56+
'mixture_proportions', 'mean_poisson', 'per_gene_background',
57+
'diffusion_mixture_proportion', 'per_neighbor_diffusion', 'px_r_inv'
58+
],
59+
num_samples=20, return_samples=False, batch_size=4000)
60+
samples = pd.DataFrame(samples).T
61+
62+
63+
adata_sp.obsm["X_resolVI"] = supervised_resolvi.get_latent_representation()
64+
65+
# TODO these 2 lines threw errors because 'obs' was not generated in samples_corr
66+
# adata_sp.layers["generated_expression"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_q25', 'obs'])
67+
# adata_sp.layers["generated_expression_mean"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_means', 'obs'])
68+
69+
adata_sp.layers["corrected_counts"] = adata_sp.layers['counts'].multiply((samples_corr.loc['post_sample_q05', 'px_rate'] / (
70+
1.0 + samples_corr.loc['post_sample_q05', 'px_rate'] + samples.loc['post_sample_means', 'mean_poisson']))).tocsr()
71+
72+
# Write output
73+
print('Writing output', flush=True)
74+
adata_sp.write(par['output'])

0 commit comments

Comments
 (0)