1+ import os
2+ import pandas as pd
3+ from seqchromloader import SeqChromDatasetByBed , SeqChromDatasetByWds , SeqChromDataModule
4+ from seqchromloader import dump_data_webdataset
5+
6+ import unittest
7+ import tempfile
8+ import shutil
9+ import pathlib as pl
10+ import webdataset as wds
11+
class Test(unittest.TestCase):
    """Integration tests for the seqchromloader writer, datasets and datamodule.

    Every test reads fixture files through relative ``data/...`` paths, so the
    suite has to be started from a location where those paths are valid
    (presumably they resolve against the current working directory).  Each
    test pulls the first batch out of a loader and compares individual
    elements with the values expected for the fixtures.

    The ``*_transform`` tests pass the module-level helpers
    ``test_seq_transform`` (adds 1), ``test_chrom_transform`` (divides by 3)
    and ``test_target_transform`` (multiplies by 3) and expect the asserted
    values to change accordingly.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create the scratch directory that ``test_writer`` dumps shards into."""
        cls.tempdir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls) -> None:
        """Delete the scratch directory together with everything written to it."""
        shutil.rmtree(cls.tempdir)

    def assertIsFile(self, path):
        """Fail the current test unless *path* resolves to an existing file."""
        if not pl.Path(path).resolve().is_file():
            self.fail("File does not exist: %s" % str(path))

    @staticmethod
    def _transforms():
        """Return the per-field transform mapping used by the ``*_transform`` tests."""
        # Looked up at call time: the helpers are defined below the class.
        return {
            "seq": test_seq_transform,
            "chrom": test_chrom_transform,
            "target": test_target_transform,
        }

    @staticmethod
    def _first_wds_batch(**extra):
        """Return the first batch read from the pre-built webdataset shard."""
        loader = SeqChromDatasetByWds(
            ["data/test_0.tar.gz"],
            dataloader_kws={"batch_size": 3},
            **extra,
        )
        return next(iter(loader))

    @staticmethod
    def _first_bed_batch(**extra):
        """Return the first batch built on the fly from the BED fixture."""
        loader = SeqChromDatasetByBed(
            bed="data/sample.bed",
            genome_fasta="data/sample.fa",
            bigwig_filelist=["data/sample.bw"],
            target_bam="data/sample.bam",
            # shuffle=False keeps the batch order fixed so that the
            # element-wise assertions in the tests are deterministic.
            dataloader_kws={"batch_size": 2, "shuffle": False},
            **extra,
        )
        return next(iter(loader))

    @staticmethod
    def _first_val_batch(**extra):
        """Return the first validation batch served by the datamodule."""
        dm = SeqChromDataModule(
            train_wds="data/test_0.tar.gz",
            val_wds="data/test_0.tar.gz",
            test_wds="data/test_0.tar.gz",
            train_dataset_size=100,
            batch_size=3,
            num_workers=1,
            patch_last=False,
            **extra,
        )
        dm.setup()
        return next(iter(dm.val_dataloader()))

    def test_writer(self):
        """Dump a webdataset shard and read its first batch back."""
        coords = pd.DataFrame({
            'chrom': ["chr19", "chr19"],
            'start': [0, 3],
            'end': [5, 8],
            'label': [0, 1],
            'score': [".", "."],
            'strand': ["+", "+"],
        })
        # Two rows repeated 5000 times -> 10,000 rows.  drop=True discards the
        # old, repeating 0/1 row labels; a plain reset_index() would insert
        # them into the frame as an unrelated extra "index" column.
        huge_coords = pd.concat([coords] * 5000, axis=0).reset_index(drop=True)
        dump_data_webdataset(huge_coords,
                             genome_fasta='data/sample.fa',
                             bigwig_filelist=['data/sample.bw'],
                             target_bam='data/sample.bam',
                             outdir=self.tempdir,
                             outprefix='test',
                             compress=True,
                             numProcessors=5)
        shard = os.path.join(self.tempdir, "test_0.tar.gz")
        self.assertIsFile(shard)

        ds = wds.DataPipeline(
            wds.SimpleShardList([shard]),
            wds.tarfile_to_samples(),
            wds.decode(),
            wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
            wds.batched(2),
        )
        seq, chrom, target, label = next(iter(ds))
        self.assertEqual(seq[1, 0, 4].item(), 1.0)
        self.assertEqual(chrom[0, 0, 3].item(), 999.0)
        self.assertEqual(target[0].item(), 2.0)
        self.assertEqual(label[1].item(), 1)

    def test_wds_loader(self):
        """Read the pre-built shard through ``SeqChromDatasetByWds``."""
        seq, chrom, target, label = self._first_wds_batch()

        self.assertEqual(seq[0, 0, 3].item(), 1.0)
        self.assertEqual(chrom[0, 0, 3].item(), 4.0)
        self.assertEqual(target[0].item(), 0.0)
        self.assertEqual(label[1].item(), 1)

    def test_wds_loader_transform(self):
        """Same as ``test_wds_loader`` but with the transforms applied."""
        seq, chrom, target, label = self._first_wds_batch(
            transforms=self._transforms())

        self.assertEqual(seq[0, 0, 3].item(), 2.0)
        # 4.0 / 3 is not exactly representable, hence the tolerant comparison.
        self.assertAlmostEqual(chrom[0, 0, 3].item(), 4.0 / 3)
        self.assertEqual(target[0].item(), 0.0)
        self.assertEqual(label[1].item(), 1)

    def test_bed_loader(self):
        """Build batches straight from BED/FASTA/bigWig/BAM fixtures."""
        seq, chrom, target, label = self._first_bed_batch()

        self.assertEqual(seq[1, 0, 4].item(), 1.0)
        self.assertEqual(chrom[0, 0, 3].item(), 999.0)
        self.assertEqual(target[0].item(), 2.0)
        self.assertEqual(label[1].item(), 1)

    def test_bed_loader_transform(self):
        """Same as ``test_bed_loader`` but with the transforms applied."""
        seq, chrom, target, label = self._first_bed_batch(
            transforms=self._transforms())

        self.assertEqual(seq[1, 0, 4].item(), 2.0)
        self.assertEqual(chrom[0, 0, 3].item(), 333.0)
        self.assertEqual(target[0].item(), 6.0)
        self.assertEqual(label[1].item(), 1)

    def test_lightning_datamodule(self):
        """Fetch a validation batch through ``SeqChromDataModule``."""
        seq, chrom, target, label = self._first_val_batch()

        self.assertEqual(seq[0, 0, 3].item(), 1.0)
        self.assertEqual(chrom[0, 0, 3].item(), 4.0)
        self.assertEqual(target[0].item(), 0.0)
        self.assertEqual(label[1].item(), 1)

    def test_lightning_datamodule_transform(self):
        """Same as ``test_lightning_datamodule`` but with the transforms applied."""
        seq, chrom, target, label = self._first_val_batch(
            transforms=self._transforms())

        self.assertEqual(seq[0, 0, 3].item(), 2.0)
        # 4.0 / 3 is not exactly representable, hence the tolerant comparison.
        self.assertAlmostEqual(chrom[0, 0, 3].item(), 4.0 / 3)
        self.assertEqual(target[0].item(), 0.0)
        self.assertEqual(label[1].item(), 1)
159+
def test_seq_transform(seq):
    """Return *seq* with 1 added to it.

    Passed to the loaders under the ``"seq"`` key of their ``transforms``
    mapping; the transform tests expect the asserted element to go from 1.0
    to 2.0.

    Args:
        seq: Any value supporting ``+ 1`` (a number or an array-like).

    Returns:
        The result of ``seq + 1``.
    """
    return seq + 1


# Helper, not a test case: without this marker the ``test_`` prefix makes
# pytest collect the function and error out looking for a fixture named "seq".
test_seq_transform.__test__ = False
162+
def test_chrom_transform(chrom):
    """Return *chrom* divided by 3.

    Passed to the loaders under the ``"chrom"`` key of their ``transforms``
    mapping; the transform tests expect e.g. 999.0 to become 333.0.

    Args:
        chrom: Any value supporting ``/ 3`` (a number or an array-like).

    Returns:
        The result of ``chrom / 3`` (true division).
    """
    return chrom / 3


# Helper, not a test case: without this marker the ``test_`` prefix makes
# pytest collect the function and error out looking for a fixture named "chrom".
test_chrom_transform.__test__ = False
165+
def test_target_transform(target):
    """Return *target* multiplied by 3.

    Passed to the loaders under the ``"target"`` key of their ``transforms``
    mapping; the transform tests expect e.g. 2.0 to become 6.0.

    Args:
        target: Any value supporting ``* 3`` (a number or an array-like).

    Returns:
        The result of ``target * 3``.
    """
    return target * 3


# Helper, not a test case: without this marker the ``test_`` prefix makes
# pytest collect the function and error out looking for a fixture named "target".
test_target_transform.__test__ = False
168+
# Run the suite with unittest's command-line runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
0 commit comments