Skip to content

Commit e8950bf

Browse files
committed
added tests
1 parent a9c9de0 commit e8950bf

2 files changed

Lines changed: 197 additions & 35 deletions

File tree

kipoiseq/dataclasses.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def __init__(self,
4242
# main 4 attributes
4343
# these should be immutable by default to not
4444
# run into any strange issues downstream.
45-
self._chrom = chrom
46-
self._pos = pos
47-
self._ref = ref
48-
self._alt = alt
45+
self._chrom = str(chrom)
46+
self._pos = int(pos)
47+
self._ref = str(ref)
48+
self._alt = str(alt)
4949

5050
# other 4 main VCF attributes
5151
self.id = id
@@ -109,15 +109,16 @@ def __hash__(self):
109109
return hash((self.chrom, self.pos, self.ref, self.alt))
110110

111111
def __str__(self):
112-
return f"{self.chrom}_{self.pos}_{self.ref}/{self.alt}"
112+
return f"{self.chrom}:{self.pos}:{self.ref}>{self.alt}"
113113

114-
def from_str(self, s):
115-
chrom, pos, ref_alt = s.split("_")
116-
ref, alt = ref_alt.split("/")
117-
return Variant(chrom=chrom, pos=pos, ref=ref, alt=alt)
114+
@classmethod
115+
def from_str(cls, s):
116+
chrom, pos, ref_alt = s.split(":")
117+
ref, alt = ref_alt.split(">")
118+
return cls(chrom=chrom, pos=int(pos), ref=ref, alt=alt)
118119

119120
def __repr__(self):
120-
return f"Variant(chrom='{self.variant}'), pos={self.pos}, ref='{self.ref}', alt='{self.alt}', id='{self.id}',..."
121+
return f"Variant(chrom='{self.chrom}', pos={self.pos}, ref='{self.ref}', alt='{self.alt}', id='{self.id}',...)"
121122

122123

123124
class Interval:
@@ -206,18 +207,18 @@ def to_pybedtools(self):
206207
def neg_strand(self):
207208
return self.strand == "-"
208209

209-
def center(self, ignore_strand=False):
210+
def center(self, use_strand=True):
210211
"""Compute the center of the interval
211212
"""
212-
if ignore_strand:
213-
add_offset = 0
214-
else:
213+
if use_strand:
215214
add_offset = 0 if self.neg_strand else 1
215+
else:
216+
add_offset = 0
216217
delta = (self.end + self.start) % 2
217218
center = (self.end + self.start) // 2
218219
return center + add_offset * delta
219220

220-
def shift(self, x: int, use_strand: bool=False):
221+
def shift(self, x: int, use_strand: bool=True):
221222
"""Shift the interval by x.
222223
223224
Args:
@@ -241,9 +242,9 @@ def shift(self, x: int, use_strand: bool=False):
241242
def swap_strand(self):
242243
obj = self.copy()
243244
if obj.strand == "+":
244-
obj.strand = "-"
245+
obj._strand = "-"
245246
elif obj.strand == "-":
246-
obj.strand = "+"
247+
obj._strand = "+"
247248
return obj
248249

249250
def __eq__(self, obj):
@@ -261,18 +262,19 @@ def __str__(self):
261262
def __repr__(self):
262263
return (f"Interval(chrom='{self.chrom}', start={self.start}, end={self.end}, name='{self.name}', strand='{self.strand}', ...)")
263264

264-
def from_str(self, s):
265+
@classmethod
266+
def from_str(cls, s):
265267
chrom, int_range, strand = s.split(":")
266268
start, end = int_range.split("-")
267-
return Interval(chrom=chrom,
268-
start=int(start),
269-
end=int(end),
270-
strand=strand)
269+
return cls(chrom=chrom,
270+
start=int(start),
271+
end=int(end),
272+
strand=strand)
271273

272274
def copy(self):
273275
return deepcopy(self)
274276

275-
def slop(self, upstream=0, downstream=0, use_strand=False):
277+
def slop(self, upstream=0, downstream=0, use_strand=True):
276278
"""Extend the interval on each strand
277279
"""
278280
obj = self.copy()
@@ -296,20 +298,21 @@ def truncate(self, chrom_len=math.inf):
296298
obj._end = min(self.end, chrom_len - 1)
297299
return obj
298300

299-
def resize(self, width):
301+
def resize(self, width, use_strand=True):
300302
obj = deepcopy(self)
301303

302304
if width is None or self.width() == width:
303305
# no need to resize
304306
return obj
305307

306-
if self.strand != "-":
308+
if self.neg_strand and use_strand:
309+
# negative strand
310+
obj._start = self.center() - width // 2
311+
obj._end = self.center() + width // 2 + width % 2
312+
else:
307313
# positive strand
308314
obj._start = self.center() - width // 2 - width % 2
309315
obj._end = self.center() + width // 2
310-
else:
311-
obj._start = self.center() - width // 2
312-
obj._end = self.center() + width // 2 + width % 2
313316
return obj
314317

315318
def width(self):
@@ -318,12 +321,12 @@ def width(self):
318321
def __len__(self):
319322
return self.width()
320323

321-
def trim(self, i, j):
324+
def trim(self, i, j, use_strand=True):
322325
if i == 0 and j == self.width():
323326
return self
324327
obj = self.copy()
325328
assert j > i
326-
if self.strand == "-":
329+
if self.strand == "-" and use_strand:
327330
w = self.width()
328331
obj._start = self.start + w - j
329332
obj._end = self.start + w - i

tests/test_dataclasses.py

Lines changed: 164 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,179 @@
88
- =, hashable
99
-
1010
# VCF
11-
- from cyvcf2
11+
- from cyvcf2
1212
1313
# Interval
1414
- validation of the interval
1515
- access to all the attributes
1616
- from_pybedtools and to_pybedtools
1717
- shift, swapt_strand, trim, etc
1818
"""
19-
20-
2119
from kipoiseq.dataclasses import Variant, Interval
20+
import pytest
21+
import cyvcf2
22+
import pybedtools
23+
2224

2325
def test_variant():
2426
v = Variant("chr1", 10, 'C', 'T')
2527

26-
#
27-
pass
28+
assert v.start == 9
29+
assert v.chrom == 'chr1'
30+
assert v.pos == 10
31+
assert v.ref == 'C'
32+
assert v.alt == 'T'
33+
assert isinstance(v.info, dict)
34+
assert len(v.info) == 0
35+
assert v.qual == 0
36+
assert v.filter == 'PASS'
37+
v.info['test'] = 10
38+
assert v.info['test'] == 10
39+
assert isinstance(str(v), str)
40+
41+
# make sure the original got unchangd
42+
v2 = v.copy()
43+
v.info['test'] = 20
44+
assert v2.info['test'] == 10
45+
v.__repr__()
46+
47+
# __str__, from_str
48+
assert v == Variant.from_str(str(v))
49+
50+
# hash test
51+
assert isinstance(hash(v), int)
52+
assert hash(v) == hash(Variant.from_str(str(v)))
53+
54+
# fixed arguments
55+
with pytest.raises(AttributeError):
56+
v.chrom = 'asd'
57+
with pytest.raises(AttributeError):
58+
v.pos = 10
59+
with pytest.raises(AttributeError):
60+
v.ref = 'asd'
61+
with pytest.raises(AttributeError):
62+
v.alt = 'asd'
63+
64+
# non-fixed arguments
65+
v.id = 'asd'
66+
v.qual = 10
67+
v.filter = 'asd'
68+
v.source = 2
69+
70+
assert isinstance(Variant("chr1", '10', 'C', 'T').pos, int)
71+
72+
# from cyvcf2
73+
vcf = cyvcf2.VCF('tests/data/test.vcf.gz')
74+
cv = list(vcf)[0]
75+
76+
v2 = Variant.from_cyvcf(cv)
77+
assert isinstance(v2.source, cyvcf2.Variant)
78+
79+
80+
def test_interval():
81+
interval = Interval("chr1", 10, 20, strand='-')
82+
interval.__repr__()
83+
84+
assert interval.start == 10
85+
assert interval.end == 20
86+
assert interval.chrom == 'chr1'
87+
assert interval.name == ''
88+
assert isinstance(interval.attrs, dict)
89+
assert len(interval.attrs) == 0
90+
interval.attrs['test'] = 10
91+
assert interval.attrs['test'] == 10
92+
assert isinstance(str(interval), str)
93+
assert interval.neg_strand
94+
95+
assert interval.width() == 10
96+
assert len(interval) == 10
97+
98+
# __str__, from_str
99+
assert interval == Interval.from_str(str(interval))
100+
101+
# make sure the original got unchangd
102+
i2 = interval.copy()
103+
interval.attrs['test'] = 20
104+
assert i2.attrs['test'] == 10
105+
106+
# hash test
107+
assert isinstance(hash(interval), int)
108+
assert hash(interval) == hash(Interval.from_str(str(interval)))
109+
110+
# fixed arguments
111+
with pytest.raises(AttributeError):
112+
interval.chrom = 'asd'
113+
with pytest.raises(AttributeError):
114+
interval.start = 10
115+
with pytest.raises(AttributeError):
116+
interval.end = 300
117+
with pytest.raises(AttributeError):
118+
interval.strand = '+'
119+
assert interval.strand == '-'
120+
121+
# non-fixed arguments
122+
interval.name = 'asd'
123+
interval.score = 10
124+
125+
assert interval == Interval.from_pybedtools(interval.to_pybedtools())
126+
assert isinstance(interval.to_pybedtools(), pybedtools.Interval)
127+
128+
i2 = interval.shift(10, use_strand=False)
129+
130+
# original unchanged
131+
assert interval.start == 10
132+
assert interval.end == 20
133+
134+
assert i2.start == 20
135+
assert i2.end == 30
136+
137+
i2 = interval.shift(10) # use_strand = True by default
138+
assert i2.start == 0
139+
assert i2.end == 10
140+
141+
assert not interval.shift(20, use_strand=True).is_valid()
142+
143+
i2 = interval.shift(15, use_strand=True).truncate()
144+
assert i2.start == 0
145+
assert i2.end == 5
146+
147+
assert interval.center() == 15
148+
149+
# resize
150+
i2 = interval.resize(11)
151+
assert i2.start == 10 and i2.end == 21
152+
153+
i2 = interval.resize(12)
154+
assert i2.start == 9 and i2.end == 21
155+
156+
i2 = interval.resize(9)
157+
assert i2.start == 11 and i2.end == 20
158+
159+
i2 = interval.swap_strand()
160+
assert interval.strand == "-"
161+
assert i2.strand == "+"
162+
assert i2.strand == '+'
163+
assert len(i2) == 10
164+
i2 = i2.resize(11)
165+
assert i2.start == 9 and i2.end == 20
166+
167+
i2 = interval.copy()
168+
assert i2.center(use_strand=True) == 15
169+
assert i2.center(use_strand=False) == 15
170+
171+
i2 = interval.swap_strand()
172+
assert i2.strand == "+"
173+
174+
i3 = i2.resize(11).shift(1)
175+
assert i3.center(use_strand=True) == 16
176+
assert i3.center(use_strand=False) == 15
177+
178+
i3 = i2.resize(11).shift(1)
179+
assert i3.center(use_strand=True) == 16
180+
assert i3.center(use_strand=False) == 15
181+
182+
i2 = interval.trim(1, 10)
183+
assert i2.start == 10 and i2.end == 19
184+
185+
i2 = interval.trim(1, 10, use_strand=False)
186+
assert i2.start == 11 and i2.end == 20

0 commit comments

Comments
 (0)