@@ -55,8 +55,8 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
5555 numProcessors = 1 ,
5656 transforms = None ,
5757 braceexpand = False ,
58- DALI = False ,
59- samples_per_tar = 10000 ):
58+ samples_per_tar = 10000 ,
59+ batch_size = None ):
6060 """
6161 Given coordinates dataframe, extract the sequence and chromatin signal, save in webdataset format
6262
@@ -106,7 +106,7 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
106106 compress = compress ,
107107 outdir = outdir ,
108108 transforms = transforms ,
109- DALI = DALI )
109+ batch_size = batch_size )
110110
111111 count_of_digits = 0
112112 nc = num_chunks
@@ -134,7 +134,7 @@ def dump_data_webdataset_worker(coords,
134134 outdir = "dataset/" ,
135135 compress = True ,
136136 transforms = None ,
137- DALI = False ,
137+ batch_size = None ,
138138 ):
139139 # get handlers
140140 genome_pyfaidx = pyfaidx .Fasta (fasta )
@@ -149,9 +149,8 @@ def dump_data_webdataset_worker(coords,
149149 # iterate all records
150150 filename = os .path .join (outdir , f"{ outprefix } .tar.gz" if compress else f"{ outprefix } .tar" )
151151 sink = wds .TarWriter (filename , compress = compress )
152+ counter = 0 ; keys = []; seq_arr = []; chrom_arr = []; target_arr = []; label_arr = []
152153 for rindex , item in enumerate (coords .itertuples ()):
153- feature_dict = defaultdict ()
154- feature_dict ["__key__" ] = f"{ rindex } _{ item .chrom } :{ item .start } -{ item .end } _{ item .strand } "
155154
156155 try :
157156 feature = utils .extract_info (
@@ -168,17 +167,35 @@ def dump_data_webdataset_worker(coords,
168167 except utils .BigWigInaccessible as e :
169168 continue
170169
171- if not DALI :
170+ if batch_size is None :
171+ feature_dict = defaultdict ()
172+ feature_dict ["__key__" ] = f"{ rindex } _{ item .chrom } :{ item .start } -{ item .end } _{ item .strand } "
172173 feature_dict ["seq.npy" ] = feature ['seq' ]
173174 feature_dict ["chrom.npy" ] = feature ['chrom' ]
174175 feature_dict ["target.npy" ] = feature ['target' ]
175176 feature_dict ["label.npy" ] = feature ['label' ]
177+ sink .write (feature_dict )
176178 else :
177- feature_dict ["seq.npy" ] = feature ['seq' ].tobytes ()
178- feature_dict ["chrom.npy" ] = feature ['chrom' ].tobytes ()
179- feature_dict ["target.npy" ] = feature ['target' ].tobytes ()
180- feature_dict ["label.npy" ] = feature ['label' ].tobytes ()
179+ counter += 1
180+ keys .append (f"{ rindex } _{ item .chrom } :{ item .start } -{ item .end } _{ item .strand } " )
181+ seq_arr .append (feature ['seq' ]); chrom_arr .append (feature ['chrom' ]); target_arr .append (feature ['target' ]); label_arr .append (feature ['label' ])
182+
183+ if counter >= batch_size :
184+ feature_dict = defaultdict ()
185+ feature_dict ["__key__" ] = ',' .join (keys )
186+ feature_dict ["seq.npy" ] = np .array (seq_arr )
187+ feature_dict ["chrom.npy" ] = np .array (chrom_arr )
188+ feature_dict ["target.npy" ] = np .array (target_arr )
189+ feature_dict ["label.npy" ] = np .array (label_arr )
190+ sink .write (feature_dict )
191+ keys , seq_arr , chrom_arr , target_arr , label_arr = [], [], [], [] ,[]
192+ counter = 0
181193
194+ if batch_size is not None :
195+ feature_dict ["seq.npy" ] = np .array (seq_arr )
196+ feature_dict ["chrom.npy" ] = np .array (chrom_arr )
197+ feature_dict ["target.npy" ] = np .array (target_arr )
198+ feature_dict ["label.npy" ] = np .array (label_arr )
182199 sink .write (feature_dict )
183200
184201 sink .close ()
0 commit comments