Skip to content

Commit 3bbfdb7

Browse files
committed
add: test function for convert_data_webdataset
1 parent d91cc7a commit 3bbfdb7

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

tests/test_writer_loader.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
sys.path.insert(0, "./")
44
import pandas as pd
55
from seqchromloader import SeqChromDatasetByDataFrame, SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
6-
from seqchromloader import dump_data_webdataset
6+
from seqchromloader import dump_data_webdataset, convert_data_webdataset
77

88
import unittest
99
import tempfile
@@ -85,6 +85,20 @@ def test_wds_loader_transform(self):
8585
self.assertEqual(target[0].item(), 0.0)
8686
self.assertEqual(label[1].item(), 1)
8787

88+
def test_wds_convert_loader(self):
89+
convert_data_webdataset("data/test_0.tar.gz", "test_0_convert.tar.gz",
90+
transforms={"seq": test_seq_transform,
91+
"chrom": test_chrom_transform,
92+
"target": test_target_transform})
93+
it = iter(SeqChromDatasetByWds(["test_0_convert.tar.gz"],
94+
dataloader_kws={"batch_size":3}))
95+
seq, chrom, target, label = next(it)
96+
97+
self.assertEqual(seq[0,0,3].item(), 2.0)
98+
self.assertAlmostEqual(chrom[0,0,3].item(), 4.0/3)
99+
self.assertEqual(target[0].item(), 0.0)
100+
self.assertEqual(label[1].item(), 1)
101+
88102
def test_df_loader(self):
89103
dataframe = pd.read_table("data/sample.bed", header=None, sep="\t", names=['chrom', 'start', 'end', 'label', 'score', 'strand' ])
90104
it = iter(SeqChromDatasetByDataFrame(

0 commit comments

Comments
 (0)