2323logger = logging .getLogger (__name__ )
2424
2525
26- def tokenize_data (nproc , model_details = None , gene_median = None , token = None , gene_mapping_file = None , tokenized_dir = None ):
26+ def tokenize_data (nproc , temp_dir , model_details = None , gene_median = None , token = None , gene_mapping_file = None , tokenized_dir = None ):
2727 """Tokenize data with required parameters"""
2828 if not all ([model_details , gene_median , token , gene_mapping_file , tokenized_dir ]):
2929 raise ValueError ("Missing required parameters for tokenization" )
@@ -37,7 +37,7 @@ def tokenize_data(nproc, model_details=None, gene_median=None, token=None, gene_
3737 )
3838
3939 tokenizer .tokenize_data (
40- "/tmp/geneformer/" , tokenized_dir , "tokenized" , file_format = "h5ad"
40+ temp_dir , tokenized_dir , "tokenized" , file_format = "h5ad"
4141 )
4242# extract embeddings
4343def get_embs (
@@ -970,6 +970,7 @@ def tryParallelFunction(func, label, **kwargs):
970970
971971def compute_geneformer_network (
972972 adata ,
973+ temp_dir ,
973974 forward_batch_size = 4 ,
974975 max_ncells = 1000 ,
975976 n_processors = 20 ,
@@ -990,19 +991,19 @@ def compute_geneformer_network(
990991 ]
991992 adata .obs ["n_counts" ] = adata .X .sum (1 )
992993 # Create the geneformer folder if it doesn't exist
993- geneformer_folder = "/tmp /geneformer"
994+ geneformer_folder = f" { temp_dir } /geneformer"
994995 if not os .path .exists (geneformer_folder ):
995996 os .makedirs (geneformer_folder )
996- adata .write_h5ad ("/tmp /geneformer/to_token.h5ad" )
997+ adata .write_h5ad (f" { temp_dir } /geneformer/to_token.h5ad" )
997998
998999 genelist = [gene_mapping_dict [u ] for u in adata .var .index ]
9991000
1000- tokenized_data_path = "/tmp /geneformer/tokenized_data.dataset"
1001+ tokenized_data_path = f" { temp_dir } /geneformer/tokenized_data.dataset"
10011002 if os .path .exists (tokenized_data_path ):
10021003 shutil .rmtree (tokenized_data_path )
10031004
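10041005 # NOTE: tokenize_data raises ValueError unless model_details, gene_median, token, gene_mapping_file and tokenized_dir are all supplied (truthy) - TODO confirm the call below forwards them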
1005- # tryParallelFunction(tokenize_data, "Tokenizing data")
1006+ tryParallelFunction (tokenize_data , "Tokenizing data" , temp_dir = geneformer_folder )
10061007
10071008 embex = EmbExtractor (
10081009 model_type = "Pretrained" , # CellClassifier
@@ -1153,6 +1154,7 @@ def main(par):
11531154 ]
11541155 subadata , net = compute_geneformer_network (
11551156 subadata ,
1157+ temp_dir = par ["temp_dir" ],
11561158 forward_batch_size = par ["batch_size" ],
11571159 n_processors = n_processors ,
11581160 max_ncells = par ["max_cells" ],
0 commit comments