Skip to content

Commit a9c9de0

Browse files
committed
added dataclasses
1 parent 7cf1b8c commit a9c9de0

2 files changed

Lines changed: 360 additions & 0 deletions

File tree

kipoiseq/dataclasses.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
"""Data classes for major objects:
2+
- Interval
3+
- Variant
4+
"""
5+
from collections import Mapping, OrderedDict
6+
from copy import deepcopy
7+
from kipoi_utils.data_utils import numpy_collate, numpy_collate_concat
8+
import math
9+
# -------------------------------------------
10+
# basepair implementation
11+
import attr
12+
13+
14+
class Variant:
15+
def __init__(self,
16+
chrom: str,
17+
pos: int, # 1-based
18+
ref: str,
19+
alt: str,
20+
id: str = '',
21+
qual: float=0,
22+
filter: str='PASS',
23+
info: dict=None,
24+
source=None):
25+
"""Variant container.
26+
27+
See also VCF file definition: http://samtools.github.io/hts-specs/VCFv4.3.pdf
28+
Note: this class doesn't hold the genotype information.
29+
30+
Args:
31+
chrom: CHROM field in the VCF
32+
pos: POS field in the VCF
33+
ref: REF field in the VCF
34+
alt: ALT field in the VCF
35+
id: ID field in the VCF
36+
qual: QUAL field in the VCF
37+
filter: FILTER field in the VCF
38+
info: INFO field in the VCF
39+
source: reference to the original source object from which this
40+
Variant object was created (e.g. `cyvcf2.Variant()` class)
41+
"""
42+
# main 4 attributes
43+
# these should be immutable by default to not
44+
# run into any strange issues downstream.
45+
self._chrom = chrom
46+
self._pos = pos
47+
self._ref = ref
48+
self._alt = alt
49+
50+
# other 4 main VCF attributes
51+
self.id = id
52+
self.qual = qual
53+
self.filter = filter
54+
self.info = info or dict()
55+
56+
# additional attribute implemented by this class
57+
self.source = source
58+
59+
def copy(self):
60+
return deepcopy(self)
61+
62+
@property
63+
def chrom(self):
64+
return self._chrom
65+
66+
@property
67+
def pos(self):
68+
return self._pos
69+
70+
@property
71+
def ref(self):
72+
return self._ref
73+
74+
@property
75+
def alt(self):
76+
return self._alt
77+
78+
# convenience properties
79+
@property
80+
def start(self):
81+
"""0-based variant start position
82+
"""
83+
return self.pos - 1
84+
85+
@classmethod
86+
def from_cyvcf(cls, obj):
87+
if len(obj.ALT) > 1:
88+
# TODO - do a proper warning
89+
print("WARNING: len(obj.ALT) > 1")
90+
91+
return cls(chrom=obj.CHROM,
92+
pos=obj.POS,
93+
ref=obj.REF,
94+
alt=obj.ALT[0], # note. we are using a single one
95+
id=obj.ID,
96+
qual=obj.QUAL,
97+
filter=obj.FILTER,
98+
info=dict(obj.INFO),
99+
source=obj,
100+
)
101+
102+
def __eq__(self, obj):
103+
return (self.chrom == obj.chrom and
104+
self.pos == obj.pos and
105+
self.ref == obj.ref and
106+
self.alt == obj.alt)
107+
108+
def __hash__(self):
109+
return hash((self.chrom, self.pos, self.ref, self.alt))
110+
111+
def __str__(self):
112+
return f"{self.chrom}_{self.pos}_{self.ref}/{self.alt}"
113+
114+
def from_str(self, s):
115+
chrom, pos, ref_alt = s.split("_")
116+
ref, alt = ref_alt.split("/")
117+
return Variant(chrom=chrom, pos=pos, ref=ref, alt=alt)
118+
119+
def __repr__(self):
120+
return f"Variant(chrom='{self.variant}'), pos={self.pos}, ref='{self.ref}', alt='{self.alt}', id='{self.id}',..."
121+
122+
123+
class Interval:
124+
"""Container for genomic interval(s)
125+
126+
All fields can be either a single values (str or int) or a
127+
numpy array of values.
128+
129+
# Arguments
130+
chrom: Chromosome
131+
start: start position
132+
end: end position
133+
name: interval name
134+
score: interval score
135+
strand: interval strand ("+", "-" or "." for unknown strand)
136+
attrs: additional attributes provided as a dictionary
137+
"""
138+
139+
def __init__(self,
140+
chrom: str,
141+
start: int, # 0-based
142+
end: int, # 0-based
143+
name: str='',
144+
score: float=0,
145+
strand: str='.',
146+
attrs: dict=None):
147+
self._chrom = chrom
148+
self._start = start
149+
self._end = end
150+
self.name = name
151+
self.score = score
152+
self._strand = strand
153+
self.attrs = attrs or dict()
154+
155+
# handle chr and stop
156+
@property
157+
def chrom(self):
158+
return self._chrom
159+
160+
@property
161+
def chr(self):
162+
return self.chrom
163+
164+
@property
165+
def start(self):
166+
return self._start
167+
168+
@property
169+
def end(self):
170+
return self._end
171+
172+
@property
173+
def stop(self):
174+
return self.end
175+
176+
@property
177+
def strand(self):
178+
return self._strand
179+
180+
@classmethod
181+
def from_pybedtools(cls, interval):
182+
"""Create the ranges object from `pybedtools.Interval`
183+
184+
# Arguments
185+
interval: `pybedtools.Interval` instance
186+
"""
187+
return cls(chrom=interval.chrom,
188+
start=interval.start,
189+
end=interval.stop,
190+
name=interval.name,
191+
score=interval.score,
192+
strand=interval.strand,
193+
attrs=dict(interval.attrs or dict())
194+
)
195+
196+
def to_pybedtools(self):
197+
import pybedtools
198+
return pybedtools.create_interval_from_list([self.chrom,
199+
self.start,
200+
self.end,
201+
self.name,
202+
self.score,
203+
self.strand])
204+
205+
@property
206+
def neg_strand(self):
207+
return self.strand == "-"
208+
209+
def center(self, ignore_strand=False):
210+
"""Compute the center of the interval
211+
"""
212+
if ignore_strand:
213+
add_offset = 0
214+
else:
215+
add_offset = 0 if self.neg_strand else 1
216+
delta = (self.end + self.start) % 2
217+
center = (self.end + self.start) // 2
218+
return center + add_offset * delta
219+
220+
def shift(self, x: int, use_strand: bool=False):
221+
"""Shift the interval by x.
222+
223+
Args:
224+
x: shift amount
225+
use_strand (bool)
226+
227+
228+
This will perform:
229+
(chrom, start + x, end + x)
230+
231+
If the strand is negative and use_strand is True, it will return:
232+
(chrom, start - x, end - x)
233+
"""
234+
obj = self.copy()
235+
if use_strand and self.neg_strand:
236+
x = - x
237+
obj._start = self.start + x
238+
obj._end = self.end + x
239+
return obj
240+
241+
def swap_strand(self):
242+
obj = self.copy()
243+
if obj.strand == "+":
244+
obj.strand = "-"
245+
elif obj.strand == "-":
246+
obj.strand = "+"
247+
return obj
248+
249+
def __eq__(self, obj):
250+
return (self.chrom == obj.chrom and
251+
self.start == obj.start and
252+
self.end == obj.end and
253+
self.strand == obj.strand)
254+
255+
def __hash__(self):
256+
return hash((self.chrom, self.start, self.end, self.strand))
257+
258+
def __str__(self):
259+
return (f"{self.chrom}:{self.start}-{self.end}:{self.strand}")
260+
261+
def __repr__(self):
262+
return (f"Interval(chrom='{self.chrom}', start={self.start}, end={self.end}, name='{self.name}', strand='{self.strand}', ...)")
263+
264+
def from_str(self, s):
265+
chrom, int_range, strand = s.split(":")
266+
start, end = int_range.split("-")
267+
return Interval(chrom=chrom,
268+
start=int(start),
269+
end=int(end),
270+
strand=strand)
271+
272+
def copy(self):
273+
return deepcopy(self)
274+
275+
def slop(self, upstream=0, downstream=0, use_strand=False):
276+
"""Extend the interval on each strand
277+
"""
278+
obj = self.copy()
279+
obj._start -= upstream
280+
obj._end += downstream
281+
return obj
282+
283+
def is_valid(self, chrom_len=math.inf):
284+
"""Check if the interval is valid
285+
"""
286+
return self.start >= 0 and self.end < chrom_len
287+
288+
def truncate(self, chrom_len=math.inf):
289+
"""Truncate the interval to become valid
290+
"""
291+
if self.is_valid(chrom_len):
292+
return self
293+
else:
294+
obj = self.copy()
295+
obj._start = max(self._start, 0)
296+
obj._end = min(self.end, chrom_len - 1)
297+
return obj
298+
299+
def resize(self, width):
300+
obj = deepcopy(self)
301+
302+
if width is None or self.width() == width:
303+
# no need to resize
304+
return obj
305+
306+
if self.strand != "-":
307+
# positive strand
308+
obj._start = self.center() - width // 2 - width % 2
309+
obj._end = self.center() + width // 2
310+
else:
311+
obj._start = self.center() - width // 2
312+
obj._end = self.center() + width // 2 + width % 2
313+
return obj
314+
315+
def width(self):
316+
return self.end - self.start
317+
318+
def __len__(self):
319+
return self.width()
320+
321+
def trim(self, i, j):
322+
if i == 0 and j == self.width():
323+
return self
324+
obj = self.copy()
325+
assert j > i
326+
if self.strand == "-":
327+
w = self.width()
328+
obj._start = self.start + w - j
329+
obj._end = self.start + w - i
330+
else:
331+
obj._start = self.start + i
332+
obj._end = self.start + j
333+
return obj

tests/test_dataclasses.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Test kipoiseq.dataclasses
2+
3+
Tests to perform:
4+
5+
# Both
6+
- make sure the immutable objects are really immutable
7+
- str() and from_str
8+
- =, hashable
9+
-
10+
# VCF
11+
- from cyvcf2
12+
13+
# Interval
14+
- validation of the interval
15+
- access to all the attributes
16+
- from_pybedtools and to_pybedtools
17+
- shift, swapt_strand, trim, etc
18+
"""
19+
20+
21+
from kipoiseq.dataclasses import Variant, Interval
22+
23+
def test_variant():
24+
v = Variant("chr1", 10, 'C', 'T')
25+
26+
#
27+
pass

0 commit comments

Comments
 (0)