Skip to content

Commit af2cf62

Browse files
authored
Merge pull request #9 from ArcInstitute/MingzeDong-patch-2
Update generation.py
2 parents 8108e14 + 21cf515 commit af2cf62

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

src/stack/cli/generation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def _run_incontext_generation(
253253
gene_name_col: Optional[str],
254254
prompt_ratio: float,
255255
context_ratio: float,
256+
context_ratio_min: float,
256257
mask_rate: float,
257258
mode: str,
258259
num_steps: Optional[int],
@@ -267,6 +268,7 @@ def _run_incontext_generation(
267268
genelist_path=genelist_path,
268269
prompt_ratio=prompt_ratio,
269270
context_ratio=context_ratio,
271+
context_ratio_min=context_ratio_min,
270272
mask_rate=mask_rate,
271273
mode=mode,
272274
num_steps=num_steps,
@@ -294,7 +296,8 @@ def generate(
294296
split_values: Optional[Sequence[str]] = None,
295297
gene_name_col: Optional[str] = None,
296298
prompt_ratio: float = 0.25,
297-
context_ratio: float = 0.25,
299+
context_ratio: float = 0.4,
300+
context_ratio_min: float = 0.2,
298301
mask_rate: float = 1.0,
299302
num_steps: Optional[int] = None,
300303
mode: str = "vanilla",
@@ -351,6 +354,7 @@ def generate(
351354
gene_name_col=gene_name_col,
352355
prompt_ratio=prompt_ratio,
353356
context_ratio=context_ratio,
357+
context_ratio_min=context_ratio_min,
354358
mask_rate=mask_rate,
355359
mode=mode,
356360
num_steps=num_steps,
@@ -437,7 +441,8 @@ def build_parser() -> argparse.ArgumentParser:
437441
help="Optional column in adata.var/raw.var containing gene symbols for alignment",
438442
)
439443
parser.add_argument("--prompt-ratio", type=float, default=0.25, help="Prompt ratio passed to in-context generation")
440-
parser.add_argument("--context-ratio", type=float, default=0.25, help="Context ratio passed to in-context generation")
444+
parser.add_argument("--context-ratio", type=float, default=0.4, help="Context ratio passed to in-context generation")
445+
parser.add_argument("--context-ratio-min", type=float, default=0.2, help="Min value of context ratio")
441446
parser.add_argument("--mask-rate", type=float, default=1.0, help="Mask rate used during in-context generation")
442447
parser.add_argument(
443448
"--num-steps",
@@ -484,6 +489,7 @@ def main(args: Optional[List[str]] = None) -> None:
484489
gene_name_col=parsed.gene_name_col,
485490
prompt_ratio=parsed.prompt_ratio,
486491
context_ratio=parsed.context_ratio,
492+
context_ratio_min=parsed.context_ratio_min,
487493
mask_rate=parsed.mask_rate,
488494
num_steps=parsed.num_steps,
489495
mode=parsed.mode,
@@ -513,6 +519,7 @@ def _stream_save(split_value: str, adata: ad.AnnData) -> None:
513519
gene_name_col=parsed.gene_name_col,
514520
prompt_ratio=parsed.prompt_ratio,
515521
context_ratio=parsed.context_ratio,
522+
context_ratio_min=parsed.context_ratio_min,
516523
mask_rate=parsed.mask_rate,
517524
num_steps=parsed.num_steps,
518525
mode=parsed.mode,

0 commit comments

Comments
 (0)