@@ -253,6 +253,7 @@ def _run_incontext_generation(
253253 gene_name_col : Optional [str ],
254254 prompt_ratio : float ,
255255 context_ratio : float ,
256+ context_ratio_min : float ,
256257 mask_rate : float ,
257258 mode : str ,
258259 num_steps : Optional [int ],
@@ -267,6 +268,7 @@ def _run_incontext_generation(
267268 genelist_path = genelist_path ,
268269 prompt_ratio = prompt_ratio ,
269270 context_ratio = context_ratio ,
271+ context_ratio_min = context_ratio_min ,
270272 mask_rate = mask_rate ,
271273 mode = mode ,
272274 num_steps = num_steps ,
@@ -294,7 +296,8 @@ def generate(
294296 split_values : Optional [Sequence [str ]] = None ,
295297 gene_name_col : Optional [str ] = None ,
296298 prompt_ratio : float = 0.25 ,
297- context_ratio : float = 0.25 ,
299+ context_ratio : float = 0.4 ,
300+ context_ratio_min : float = 0.2 ,
298301 mask_rate : float = 1.0 ,
299302 num_steps : Optional [int ] = None ,
300303 mode : str = "vanilla" ,
@@ -351,6 +354,7 @@ def generate(
351354 gene_name_col = gene_name_col ,
352355 prompt_ratio = prompt_ratio ,
353356 context_ratio = context_ratio ,
357+ context_ratio_min = context_ratio_min ,
354358 mask_rate = mask_rate ,
355359 mode = mode ,
356360 num_steps = num_steps ,
@@ -437,7 +441,8 @@ def build_parser() -> argparse.ArgumentParser:
437441 help = "Optional column in adata.var/raw.var containing gene symbols for alignment" ,
438442 )
439443 parser .add_argument ("--prompt-ratio" , type = float , default = 0.25 , help = "Prompt ratio passed to in-context generation" )
440- parser .add_argument ("--context-ratio" , type = float , default = 0.25 , help = "Context ratio passed to in-context generation" )
444+ parser .add_argument ("--context-ratio" , type = float , default = 0.4 , help = "Context ratio passed to in-context generation" )
445+ parser .add_argument ("--context-ratio-min" , type = float , default = 0.2 , help = "Min value of context ratio" )
441446 parser .add_argument ("--mask-rate" , type = float , default = 1.0 , help = "Mask rate used during in-context generation" )
442447 parser .add_argument (
443448 "--num-steps" ,
@@ -484,6 +489,7 @@ def main(args: Optional[List[str]] = None) -> None:
484489 gene_name_col = parsed .gene_name_col ,
485490 prompt_ratio = parsed .prompt_ratio ,
486491 context_ratio = parsed .context_ratio ,
492+ context_ratio_min = parsed .context_ratio_min ,
487493 mask_rate = parsed .mask_rate ,
488494 num_steps = parsed .num_steps ,
489495 mode = parsed .mode ,
@@ -513,6 +519,7 @@ def _stream_save(split_value: str, adata: ad.AnnData) -> None:
513519 gene_name_col = parsed .gene_name_col ,
514520 prompt_ratio = parsed .prompt_ratio ,
515521 context_ratio = parsed .context_ratio ,
522+ context_ratio_min = parsed .context_ratio_min ,
516523 mask_rate = parsed .mask_rate ,
517524 num_steps = parsed .num_steps ,
518525 mode = parsed .mode ,
0 commit comments