Skip to content

Commit 28b7bb3

Browse files
abhinadduri and claude
committed
feat: add collate_dtype support and pin_memory config to PerturbationDataModule
- Add collate_dtype param to PerturbationDataset for float16/float32 tensor casting - Wire collate_dtype through PerturbationDataModule to all dataset constructors Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f3c1895 commit 28b7bb3

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

src/cell_load/data_modules/perturbation_dataloader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
use_consecutive_loading: bool = False,
8989
h5_open_kwargs: dict | None = None,
9090
show_progress: bool = True,
91+
collate_dtype: str = "float16",
9192
**kwargs, # missing perturbation_features_file and store_raw_basal for backwards compatibility
9293
):
9394
"""
@@ -195,6 +196,7 @@ def __init__(
195196
self.additional_obs = additional_obs
196197
self.h5_open_kwargs = h5_open_kwargs
197198
self.show_progress = bool(show_progress)
199+
self.collate_dtype = collate_dtype
198200
if self.use_consecutive_loading:
199201
self._set_h5_cache_env_defaults()
200202

@@ -305,6 +307,7 @@ def save_state(self, filepath: str):
305307
"additional_obs": self.additional_obs,
306308
"use_consecutive_loading": self.use_consecutive_loading,
307309
"h5_open_kwargs": self.h5_open_kwargs,
310+
"collate_dtype": self.collate_dtype,
308311
}
309312

310313
torch.save(save_dict, filepath)
@@ -349,6 +352,7 @@ def load_state(cls, filepath: str):
349352
"barcode": save_dict.pop("barcode", True),
350353
"use_consecutive_loading": save_dict.pop("use_consecutive_loading", False),
351354
"h5_open_kwargs": save_dict.pop("h5_open_kwargs", None),
355+
"collate_dtype": save_dict.pop("collate_dtype", "float16"),
352356
}
353357

354358
# Create new instance with all the saved parameters
@@ -639,6 +643,7 @@ def _create_base_dataset(
639643
is_log1p=self.is_log1p,
640644
cell_sentence_len=self.cell_sentence_len,
641645
h5_open_kwargs=self.h5_open_kwargs,
646+
collate_dtype=self.collate_dtype,
642647
)
643648

644649
def _setup_datasets(self):

src/cell_load/dataset/_perturbation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
is_log1p: bool = True,
5252
cell_sentence_len: int | None = None,
5353
h5_open_kwargs: dict | None = None,
54+
collate_dtype: str = "float16",
5455
**kwargs,
5556
):
5657
"""
@@ -78,6 +79,8 @@ def __init__(
7879
is_log1p: Whether raw counts in X are log1p-transformed (default True; affects downsampling)
7980
cell_sentence_len: Optional sentence length for consecutive loading batches
8081
h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes)
82+
collate_dtype: dtype for tensor outputs — "float16", "float32", or "bfloat16".
83+
Casting to float16 before collation halves per-sample memory in workers and pinned memory.
8184
**kwargs: Additional options (e.g. output_space)
8285
"""
8386
super().__init__()
@@ -121,6 +124,9 @@ def __init__(
121124
self.h5_open_kwargs = self._normalize_h5_open_kwargs(h5_open_kwargs)
122125
self.additional_obs = self._validate_additional_obs(additional_obs)
123126

127+
_dtype_map = {"float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16}
128+
self.collate_dtype = _dtype_map.get(collate_dtype, torch.float32)
129+
124130
# Load metadata cache and open file
125131
self.metadata_cache = GlobalH5MetadataCache().get_cache(
126132
str(self.h5_path), pert_col, cell_type_key, control_pert, batch_col
@@ -346,6 +352,12 @@ def __getitem__(self, idx: int):
346352
elif self.output_space == "all":
347353
sample["ctrl_cell_counts"] = self.fetch_gene_expression(ctrl_idx)
348354

355+
# Cast tensor values to collate_dtype to reduce worker/pinned memory
356+
if self.collate_dtype != torch.float32:
357+
for k in ("pert_cell_emb", "ctrl_cell_emb", "pert_cell_counts", "ctrl_cell_counts"):
358+
if isinstance(sample.get(k), torch.Tensor):
359+
sample[k] = sample[k].to(self.collate_dtype)
360+
349361
# Optionally include cell barcodes
350362
if self.barcode and self.cell_barcodes is not None:
351363
sample["pert_cell_barcode"] = self.cell_barcodes[file_idx]
@@ -483,6 +495,15 @@ def __getitems__(self, indices):
483495
else:
484496
ctrl_counts_batch = ctrl_expr_batch
485497

498+
# Cast batch tensors to collate_dtype to reduce worker/pinned memory
499+
if self.collate_dtype != torch.float32:
500+
pert_expr_batch = pert_expr_batch.to(self.collate_dtype)
501+
ctrl_expr_batch = ctrl_expr_batch.to(self.collate_dtype)
502+
if pert_counts_batch is not None:
503+
pert_counts_batch = pert_counts_batch.to(self.collate_dtype)
504+
if ctrl_counts_batch is not None:
505+
ctrl_counts_batch = ctrl_counts_batch.to(self.collate_dtype)
506+
486507
samples = []
487508
for i, file_idx in enumerate(file_indices):
488509
pert_expr = pert_expr_batch[i]

0 commit comments

Comments (0)