@@ -51,6 +51,7 @@ def __init__(
5151 is_log1p : bool = True ,
5252 cell_sentence_len : int | None = None ,
5353 h5_open_kwargs : dict | None = None ,
54+ collate_dtype : str = "float16" ,
5455 ** kwargs ,
5556 ):
5657 """
@@ -78,6 +79,8 @@ def __init__(
7879 is_log1p: Whether raw counts in X are log1p-transformed (default True; affects downsampling)
7980 cell_sentence_len: Optional sentence length for consecutive loading batches
8081 h5_open_kwargs: Optional kwargs to pass to h5py.File (e.g., rdcc_nbytes)
82+ collate_dtype: dtype for tensor outputs — "float16", "float32", or "bfloat16".
83+ Casting to float16 before collation halves per-sample memory in workers and pinned memory.
8184 **kwargs: Additional options (e.g. output_space)
8285 """
8386 super ().__init__ ()
@@ -121,6 +124,9 @@ def __init__(
121124 self .h5_open_kwargs = self ._normalize_h5_open_kwargs (h5_open_kwargs )
122125 self .additional_obs = self ._validate_additional_obs (additional_obs )
123126
127+ _dtype_map = {"float16" : torch .float16 , "float32" : torch .float32 , "bfloat16" : torch .bfloat16 }
128+ self .collate_dtype = _dtype_map .get (collate_dtype , torch .float32 )
129+
124130 # Load metadata cache and open file
125131 self .metadata_cache = GlobalH5MetadataCache ().get_cache (
126132 str (self .h5_path ), pert_col , cell_type_key , control_pert , batch_col
@@ -346,6 +352,12 @@ def __getitem__(self, idx: int):
346352 elif self .output_space == "all" :
347353 sample ["ctrl_cell_counts" ] = self .fetch_gene_expression (ctrl_idx )
348354
355+ # Cast tensor values to collate_dtype to reduce worker/pinned memory
356+ if self .collate_dtype != torch .float32 :
357+ for k in ("pert_cell_emb" , "ctrl_cell_emb" , "pert_cell_counts" , "ctrl_cell_counts" ):
358+ if isinstance (sample .get (k ), torch .Tensor ):
359+ sample [k ] = sample [k ].to (self .collate_dtype )
360+
349361 # Optionally include cell barcodes
350362 if self .barcode and self .cell_barcodes is not None :
351363 sample ["pert_cell_barcode" ] = self .cell_barcodes [file_idx ]
@@ -483,6 +495,15 @@ def __getitems__(self, indices):
483495 else :
484496 ctrl_counts_batch = ctrl_expr_batch
485497
498+ # Cast batch tensors to collate_dtype to reduce worker/pinned memory
499+ if self .collate_dtype != torch .float32 :
500+ pert_expr_batch = pert_expr_batch .to (self .collate_dtype )
501+ ctrl_expr_batch = ctrl_expr_batch .to (self .collate_dtype )
502+ if pert_counts_batch is not None :
503+ pert_counts_batch = pert_counts_batch .to (self .collate_dtype )
504+ if ctrl_counts_batch is not None :
505+ ctrl_counts_batch = ctrl_counts_batch .to (self .collate_dtype )
506+
486507 samples = []
487508 for i , file_idx in enumerate (file_indices ):
488509 pert_expr = pert_expr_batch [i ]
0 commit comments