Skip to content

Commit d7f44b6

Browse files
committed
Add batch_size, remove DALI
add batch_size option to webdataset writer to reduce overhead, remove DALI option
1 parent 186d844 commit d7f44b6

2 files changed

Lines changed: 37 additions & 11 deletions

File tree

seqchromloader/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def random_coords(gs:str=None, genome:str=None, incl:BedTool=None, excl:BedTool=
172172
.shuffle(seed=seed, **shuffle_kwargs)
173173
.to_dataframe()[["chrom", "start", "end"]])
174174

175+
def motif_scan(motif):
176+
"""Scan the genome for regions that gives a high precision against given motif
177+
178+
:arg1: TODO
179+
:returns: TODO
180+
181+
"""
182+
pass
183+
175184
def chop_genome(chroms:list=None, incl:BedTool=None, excl:BedTool=None, gs=None, genome=None, stride=500, l=500):
176185
"""
177186
Given a genome size file and chromosome list,

seqchromloader/writer.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
5555
numProcessors=1,
5656
transforms=None,
5757
braceexpand=False,
58-
DALI=False,
59-
samples_per_tar=10000):
58+
samples_per_tar=10000,
59+
batch_size=None):
6060
"""
6161
Given coordinates dataframe, extract the sequence and chromatin signal, save in webdataset format
6262
@@ -106,7 +106,7 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
106106
compress=compress,
107107
outdir=outdir,
108108
transforms=transforms,
109-
DALI=DALI)
109+
batch_size=batch_size)
110110

111111
count_of_digits = 0
112112
nc = num_chunks
@@ -134,7 +134,7 @@ def dump_data_webdataset_worker(coords,
134134
outdir="dataset/",
135135
compress=True,
136136
transforms=None,
137-
DALI=False,
137+
batch_size=None,
138138
):
139139
# get handlers
140140
genome_pyfaidx = pyfaidx.Fasta(fasta)
@@ -149,9 +149,8 @@ def dump_data_webdataset_worker(coords,
149149
# iterate all records
150150
filename = os.path.join(outdir, f"{outprefix}.tar.gz" if compress else f"{outprefix}.tar")
151151
sink = wds.TarWriter(filename, compress=compress)
152+
counter = 0; keys = []; seq_arr = []; chrom_arr = []; target_arr = []; label_arr = []
152153
for rindex, item in enumerate(coords.itertuples()):
153-
feature_dict = defaultdict()
154-
feature_dict["__key__"] = f"{rindex}_{item.chrom}:{item.start}-{item.end}_{item.strand}"
155154

156155
try:
157156
feature = utils.extract_info(
@@ -168,17 +167,35 @@ def dump_data_webdataset_worker(coords,
168167
except utils.BigWigInaccessible as e:
169168
continue
170169

171-
if not DALI:
170+
if batch_size is None:
171+
feature_dict = defaultdict()
172+
feature_dict["__key__"] = f"{rindex}_{item.chrom}:{item.start}-{item.end}_{item.strand}"
172173
feature_dict["seq.npy"] = feature['seq']
173174
feature_dict["chrom.npy"] = feature['chrom']
174175
feature_dict["target.npy"] = feature['target']
175176
feature_dict["label.npy"] = feature['label']
177+
sink.write(feature_dict)
176178
else:
177-
feature_dict["seq.npy"] = feature['seq'].tobytes()
178-
feature_dict["chrom.npy"] = feature['chrom'].tobytes()
179-
feature_dict["target.npy"] = feature['target'].tobytes()
180-
feature_dict["label.npy"] = feature['label'].tobytes()
179+
counter += 1
180+
keys.append(f"{rindex}_{item.chrom}:{item.start}-{item.end}_{item.strand}")
181+
seq_arr.append(feature['seq']); chrom_arr.append(feature['chrom']); target_arr.append(feature['target']); label_arr.append(feature['label'])
182+
183+
if counter>=batch_size:
184+
feature_dict = defaultdict()
185+
feature_dict["__key__"] = ','.join(keys)
186+
feature_dict["seq.npy"] = np.array(seq_arr)
187+
feature_dict["chrom.npy"] = np.array(chrom_arr)
188+
feature_dict["target.npy"] = np.array(target_arr)
189+
feature_dict["label.npy"] = np.array(label_arr)
190+
sink.write(feature_dict)
191+
keys, seq_arr, chrom_arr, target_arr, label_arr = [], [], [], [] ,[]
192+
counter = 0
181193

194+
if batch_size is not None:
195+
feature_dict["seq.npy"] = np.array(seq_arr)
196+
feature_dict["chrom.npy"] = np.array(chrom_arr)
197+
feature_dict["target.npy"] = np.array(target_arr)
198+
feature_dict["label.npy"] = np.array(label_arr)
182199
sink.write(feature_dict)
183200

184201
sink.close()

0 commit comments

Comments
 (0)