Skip to content

Commit 14e4662

Browse files
committed
add: convert_data_webdataset to transform an existing wds file
1 parent c38e9d9 commit 14e4662

3 files changed

Lines changed: 38 additions & 5 deletions

File tree

seqchromloader/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .loader import SeqChromDatasetByDataFrame, SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
2-
from .writer import dump_data_webdataset
2+
from .writer import dump_data_webdataset, convert_data_webdataset

seqchromloader/loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ class _SeqChromDatasetByWds(IterableDataset):
5757
:param transforms: A dictionary of functions to transform the output data, accepted keys are **["seq", "chrom", "target", "label"]**
5858
:type transforms: dict of functions
5959
"""
60-
def __init__(self, wds, transforms:dict=None, rank=0, world_size=1):
60+
def __init__(self, wds, transforms:dict=None, rank=0, world_size=1, keep_key=False):
6161
self.wds = wds
6262
self.transforms = transforms
6363

6464
self.rank = rank
6565
self.world_size = world_size
66+
self.keep_key = keep_key
6667

6768
def initialize(self):
6869
# this function will be called by worker_init_function in DataLoader
@@ -85,7 +86,10 @@ def __iter__(self):
8586
if self.transforms is not None:
8687
pipeline.append(wds.map_dict(**self.transforms))
8788

88-
pipeline.append(wds.to_tuple("seq", "chrom", "target", "label"))
89+
if self.keep_key:
90+
pipeline.append(wds.to_tuple("__key__", "seq", "chrom", "target", "label"))
91+
else:
92+
pipeline.append(wds.to_tuple("seq", "chrom", "target", "label"))
8993

9094
ds = wds.DataPipeline(*pipeline)
9195

seqchromloader/writer.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,43 @@
1717
import pyBigWig
1818
import webdataset as wds
1919

20-
from seqchromloader import utils
20+
from . import utils
21+
from .loader import _SeqChromDatasetByWds
2122

23+
def convert_data_webdataset(wds_in, wds_out, transforms=None, compress=False):
    """
    Transform the provided webdataset and write the result to a new webdataset file.

    Every sample of ``wds_in`` is read through ``_SeqChromDatasetByWds`` with
    ``keep_key=True`` so that the original ``__key__`` of each sample is carried
    over to the output unchanged. The four fields are written back under the
    names ``seq.npy``, ``chrom.npy``, ``target.npy`` and ``label.npy``.

    NOTE(review): the input is presumably read lazily while the output is being
    written, so ``wds_out`` should not point at the same file as ``wds_in`` --
    confirm against the loader before relying on in-place conversion.

    :param wds_in: input webdataset file
    :type wds_in: string
    :param wds_out: output webdataset file
    :type wds_out: string
    :param transforms: A dictionary of functions to transform the output data, accepted keys are *["seq", "chrom", "target", "label"]*
    :type transforms: dict of functions
    :param compress: whether to compress the output file
    :type compress: boolean
    """
    ds = _SeqChromDatasetByWds(wds_in, transforms=transforms, keep_key=True)
    sink = wds.TarWriter(wds_out, compress=compress)
    try:
        for key, seq, chrom, target, label in ds:
            # A plain dict is all the writer needs; the previous
            # ``defaultdict()`` had no default factory and so added nothing.
            sink.write({
                "__key__": key,
                "seq.npy": seq,
                "chrom.npy": chrom,
                "target.npy": target,
                "label.npy": label,
            })
    finally:
        # Always release the tar handle, even when reading, transforming or
        # writing a sample raises, so the output file is not left open.
        sink.close()
49+
2250
def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
2351
target_bam=None,
2452
outdir="dataset/", outprefix="seqchrom",
2553
compress=True,
2654
numProcessors=1,
27-
transforms=None):
55+
transforms=None,
56+
DALI=False):
2857
"""
2958
Given coordinates dataframe, extract the sequence and chromatin signal, save in webdataset format
3059

0 commit comments

Comments
 (0)