Skip to content

Commit 80fa04c

Browse files
committed
Add samples_per_tar param to control the number of tar files
1 parent 83fd561 commit 80fa04c

2 files changed

Lines changed: 8 additions & 4 deletions

File tree

seqchromloader/writer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
5555
numProcessors=1,
5656
transforms=None,
5757
braceexpand=False,
58-
DALI=False):
58+
DALI=False,
59+
samples_per_tar=10000):
5960
"""
6061
Given coordinates dataframe, extract the sequence and chromatin signal, save in webdataset format
6162
@@ -83,13 +84,15 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
8384
:param braceexpand: boolean
8485
:param DALI: Set to True if you want to use the dataset with NVIDIA DALI; all arrays are then saved as bytes, which loses the array shape info
8586
:param DALI: boolean
87+
:param samples_per_tar: Number of samples included per tar file
88+
:param samples_per_tar: int
8689
"""
8790
# check parameters
8891
if (target_bam is not None and target_bw is not None):
8992
raise Exception("Only one of target_bam and target_bw should be provided!")
9093

9194
# split coordinates and assign chunks to workers
92-
num_chunks = math.ceil(len(coords) / 7000)
95+
num_chunks = math.ceil(len(coords) / samples_per_tar)
9396
chunks = np.array_split(coords, num_chunks)
9497

9598
# freeze the common parameters
@@ -131,7 +134,8 @@ def dump_data_webdataset_worker(coords,
131134
outdir="dataset/",
132135
compress=True,
133136
transforms=None,
134-
DALI=False):
137+
DALI=False,
138+
):
135139
# get handlers
136140
genome_pyfaidx = pyfaidx.Fasta(fasta)
137141
bigwigs = [pyBigWig.open(bw) for bw in bigwig_files] if bigwig_files is not None else None

tests/test_writer_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_write_load_target_bam(self):
9292
'score': [".", "."],
9393
'strand': ["+", "+"]
9494
})
95-
huge_coords = pd.concat([coords] * 5000, axis=0).reset_index()
95+
huge_coords = pd.concat([coords] * 6000, axis=0).reset_index()
9696
dump_data_webdataset(huge_coords,
9797
genome_fasta='data/sample.fa',
9898
bigwig_filelist=['data/sample.bw'],

0 commit comments

Comments
 (0)