Skip to content

Commit 828bdc3

Browse files
author
Habib Rehman
committed
Added parameters for resolvi
1 parent 3a0034e commit 828bdc3

2 files changed

Lines changed: 29 additions & 8 deletions

File tree

src/methods_expression_correction/resolvi_correction/config.vsh.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@ arguments:
1818
type: string
1919
default: cell_type
2020

21+
- name: --n_hidden
22+
required: false
23+
direction: input
24+
type: int
25+
default: 32
26+
27+
- name: --encode_covariates
28+
required: false
29+
direction: input
30+
type: boolean
31+
default: false
32+
33+
- name: --downsample_counts
34+
required: false
35+
direction: input
36+
type: boolean
37+
default: true
38+
2139
resources:
2240
- type: python_script
2341
path: script.py

src/methods_expression_correction/resolvi_correction/script.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
'input_spatial_with_cell_types': 'resources_test/task_ist_preprocessing/mouse_brain_combined/spatial_with_cell_types.h5ad',
1414
'celltype_key': 'cell_type',
1515
'output': '../resolvi_spatial_corrected.h5ad',
16+
'n_hidden': 32,
17+
'encode_covariates': False,
18+
'downsample_counts': True
1619
}
1720
meta = {
1821
'name': 'gene_efficiency_correction',
@@ -26,9 +29,6 @@
2629
adata_sp = ad.read_h5ad(par['input_spatial_with_cell_types'])
2730
adata_sp.layers["normalized_uncorrected"] = adata_sp.layers["normalized"]
2831

29-
30-
#TODO add resolvi here
31-
3232
print("Filter cells with <5 counts")
3333
sc.pp.filter_cells(adata_sp, min_genes=5)
3434

@@ -40,13 +40,16 @@
4040

4141
scvi.external.RESOLVI.setup_anndata(adata_sp, labels_key=par['celltype_key'], layer="counts")
4242

43-
supervised_resolvi = scvi.external.RESOLVI(adata_sp, semisupervised=True)
44-
43+
supervised_resolvi = scvi.external.RESOLVI(adata_sp, semisupervised=True,
44+
n_hidden = par['n_hidden'],
45+
encode_covariates = par['encode_covariates'],
46+
downsample_counts = par['downsample_counts'])
4547
supervised_resolvi.train(max_epochs=50)
4648

4749
samples_corr = supervised_resolvi.sample_posterior(
4850
model=supervised_resolvi.module.model_corrected,
49-
return_sites=['px_rate', 'obs'],
51+
return_sites=['px_rate'],
52+
summary_fun={"post_sample_q50": np.median},
5053
num_samples=20, return_samples=False, batch_size=4000) #batch_steps was not a parameter
5154
samples_corr = pd.DataFrame(samples_corr).T
5255

@@ -66,8 +69,8 @@
6669
# adata_sp.layers["generated_expression"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_q25', 'obs'])
6770
# adata_sp.layers["generated_expression_mean"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_means', 'obs'])
6871

69-
adata_sp.layers["corrected_counts"] = adata_sp.layers['counts'].multiply((samples_corr.loc['post_sample_q05', 'px_rate'] / (
70-
1.0 + samples_corr.loc['post_sample_q05', 'px_rate'] + samples.loc['post_sample_means', 'mean_poisson']))).tocsr()
72+
adata_sp.layers["corrected_counts"] = adata_sp.layers['counts'].multiply((samples_corr.loc['post_sample_q50', 'px_rate'] / (
73+
1.0 + samples_corr.loc['post_sample_q50', 'px_rate'] + samples.loc['post_sample_means', 'mean_poisson']))).tocsr()
7174

7275
# Write output
7376
print('Writing output', flush=True)

0 commit comments

Comments
 (0)