Skip to content

Commit d30eb1b

Browse files
author
Muhammed Hasan
authored
Merge pull request #33 from MuhammedHasan/master
Split variant based on interval start-end if fixed_len, fixes #32
2 parents 4ee87d5 + 891e793 commit d30eb1b

2 files changed

Lines changed: 43 additions & 15 deletions

File tree

kipoiseq/extractors/vcf_seq.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,9 @@ def extract(self, interval, variants, anchor, fixed_len=True):
8888
which to query the sequence. 0-based
8989
variants List[cyvcf2.Variant]: variants overlapping the `interval`.
9090
can also be indels. 1-based
91-
anchor: position w.r.t. the interval start. (0-based). E.g.
92-
for an interval of `chr1:10-20` the anchor of 0 denotes
93-
the point chr1:10 in the 0-based coordinate system. Similarly,
94-
`anchor=5` means the anchor point is right in the middle
95-
of the sequence e.g. first half of the sequence (5nt) will be
96-
upstream of the anchor and the second half (5nt) will be
97-
downstream of the anchor.
91+
anchor: absolution position w.r.t. the interval start. (0-based).
92+
E.g. for an interval of `chr1:10-20` the anchor of 10 denotes
93+
the point chr1:10 in the 0-based coordinate system.
9894
fixed_len: if True, the return sequence will have the same length
9995
as the `interval` (e.g. `interval.end - interval.start`)
10096
@@ -106,7 +102,16 @@ def extract(self, interval, variants, anchor, fixed_len=True):
106102
variant_pairs = self._variant_to_sequence(variants)
107103

108104
# 1. Split variants overlapping with anchor
109-
variant_pairs = list(self._split_overlapping(variant_pairs, anchor))
105+
# and interval start end if not fixed_len
106+
variant_pairs = self._split_overlapping(variant_pairs, anchor)
107+
108+
if not fixed_len:
109+
variant_pairs = self._split_overlapping(
110+
variant_pairs, interval.start, which='right')
111+
variant_pairs = self._split_overlapping(
112+
variant_pairs, interval.end, which='left')
113+
114+
variant_pairs = list(variant_pairs)
110115

111116
# 2. split the variants into upstream and downstream
112117
# and sort the variants in each interval
@@ -168,15 +173,17 @@ def _variant_to_sequence(self, variants):
168173
start=v.start, end=v.start + len(v.ALT[0]))
169174
yield ref, alt
170175

171-
def _split_overlapping(self, variant_pairs, anchor):
176+
def _split_overlapping(self, variant_pairs, anchor, which='both'):
172177
"""
173178
Split the variants hitting the anchor into two
174179
"""
175180
for ref, alt in variant_pairs:
176181
if ref.start < anchor < ref.end or alt.start < anchor < alt.end:
177182
mid = anchor - ref.start
178-
yield ref[:mid], alt[:mid]
179-
yield ref[mid:], alt[mid:]
183+
if which == 'left' or which == 'both':
184+
yield ref[:mid], alt[:mid]
185+
if which == 'right' or which == 'both':
186+
yield ref[mid:], alt[mid:]
180187
else:
181188
yield ref, alt
182189

@@ -201,7 +208,7 @@ def _downstream_builder(self, down_variants, interval, anchor, istart):
201208

202209
prev = anchor
203210
for ref, alt in down_variants:
204-
if ref.end <= istart:
211+
if ref.end < istart:
205212
break
206213
down_sb.append(Interval(interval.chrom, ref.end, prev))
207214
down_sb.append(alt)
@@ -239,14 +246,28 @@ def _cut_to_fix_len(self, down_str, up_str, interval, anchor):
239246

240247

241248
class BaseVCFSeqExtractor(BaseExtractor):
249+
"""
250+
Base class to fetch sequence in which variants applied based
251+
on given vcf file.
252+
"""
253+
242254
def __init__(self, fasta_file, vcf_file):
255+
"""
256+
Args:
257+
fasta_file: path to the fasta file (can be gzipped)
258+
vcf_file: path to the fasta file (need be bgzipped and indexed)
259+
"""
243260
self.fasta_file = fasta_file
244261
self.vcf_file = vcf_file
245262
self.variant_extractor = VariantSeqExtractor(fasta_file)
246263
self.vcf = MultiSampleVCF(vcf_file)
247264

248265

249266
class SingleVariantVCFSeqExtractor(BaseVCFSeqExtractor):
267+
"""
268+
Fetch list of sequence in which each variant applied based
269+
on given vcf file.
270+
"""
250271

251272
def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
252273
for variant in self.vcf.fetch_variants(interval, sample_id):
@@ -257,6 +278,9 @@ def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
257278

258279

259280
class SingleSeqVCFSeqExtractor(BaseVCFSeqExtractor):
281+
"""
282+
Fetch sequence in which all variant applied based on given vcf file.
283+
"""
260284

261285
def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
262286
return self.variant_extractor.extract(

tests/extractors/test_vcf_seq_extractor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,14 @@ def test__split_overlapping(variant_seq_extractor):
9292

9393

9494
def test_extract(variant_seq_extractor):
95-
interval = Interval('chr1', 2, 9)
96-
9795
variants = list(VCF(vcf_file)())
96+
97+
interval = Interval('chr1', 2, 9)
9898
seq = variant_seq_extractor.extract(interval, variants, anchor=5)
9999
assert len(seq) == interval.end - interval.start
100100
assert seq == 'GCGAACG'
101101

102102
interval = Interval('chr1', 2, 9, strand='-')
103-
variants = list(VCF(vcf_file)())
104103
seq = variant_seq_extractor.extract(interval, variants, anchor=5)
105104
assert len(seq) == interval.end - interval.start
106105
assert seq == 'CGTTCGC'
@@ -140,6 +139,11 @@ def test_extract(variant_seq_extractor):
140139
assert len(seq) == interval.end - interval.start
141140
assert seq == 'AACGTAACGT'
142141

142+
interval = Interval('chr1', 5, 11, strand='+')
143+
seq = variant_seq_extractor.extract(
144+
interval, variants, anchor=10, fixed_len=False)
145+
assert seq == 'AACGTAA'
146+
143147

144148
@pytest.fixture
145149
def single_variant_vcf_seq_extractor():

0 commit comments

Comments
 (0)