@@ -88,13 +88,9 @@ def extract(self, interval, variants, anchor, fixed_len=True):
8888 which to query the sequence. 0-based
8989 variants List[cyvcf2.Variant]: variants overlapping the `interval`.
9090 can also be indels. 1-based
91- anchor: position w.r.t. the interval start. (0-based). E.g.
92- for an interval of `chr1:10-20` the anchor of 0 denotes
93- the point chr1:10 in the 0-based coordinate system. Similarly,
94- `anchor=5` means the anchor point is right in the middle
95- of the sequence e.g. first half of the sequence (5nt) will be
96- upstream of the anchor and the second half (5nt) will be
97- downstream of the anchor.
91+ anchor: absolute position, not an offset from the interval start. (0-based).
92+ E.g. for an interval of `chr1:10-20` the anchor of 10 denotes
93+ the point chr1:10 in the 0-based coordinate system.
9894 fixed_len: if True, the return sequence will have the same length
9995 as the `interval` (e.g. `interval.end - interval.start`)
10096
@@ -106,7 +102,16 @@ def extract(self, interval, variants, anchor, fixed_len=True):
106102 variant_pairs = self ._variant_to_sequence (variants )
107103
108104 # 1. Split variants overlapping with anchor
109- variant_pairs = list (self ._split_overlapping (variant_pairs , anchor ))
105+ # and with the interval start and end if not fixed_len
106+ variant_pairs = self ._split_overlapping (variant_pairs , anchor )
107+
108+ if not fixed_len :
109+ variant_pairs = self ._split_overlapping (
110+ variant_pairs , interval .start , which = 'right' )
111+ variant_pairs = self ._split_overlapping (
112+ variant_pairs , interval .end , which = 'left' )
113+
114+ variant_pairs = list (variant_pairs )
110115
111116 # 2. split the variants into upstream and downstream
112117 # and sort the variants in each interval
@@ -168,15 +173,17 @@ def _variant_to_sequence(self, variants):
168173 start = v .start , end = v .start + len (v .ALT [0 ]))
169174 yield ref , alt
170175
171- def _split_overlapping (self , variant_pairs , anchor ):
176+ def _split_overlapping (self , variant_pairs , anchor , which = 'both' ):
172177 """
173178 Split the variants hitting the anchor into two
174179 """
175180 for ref , alt in variant_pairs :
176181 if ref .start < anchor < ref .end or alt .start < anchor < alt .end :
177182 mid = anchor - ref .start
178- yield ref [:mid ], alt [:mid ]
179- yield ref [mid :], alt [mid :]
183+ if which == 'left' or which == 'both' :
184+ yield ref [:mid ], alt [:mid ]
185+ if which == 'right' or which == 'both' :
186+ yield ref [mid :], alt [mid :]
180187 else :
181188 yield ref , alt
182189
@@ -201,7 +208,7 @@ def _downstream_builder(self, down_variants, interval, anchor, istart):
201208
202209 prev = anchor
203210 for ref , alt in down_variants :
204- if ref .end <= istart :
211+ if ref .end < istart :
205212 break
206213 down_sb .append (Interval (interval .chrom , ref .end , prev ))
207214 down_sb .append (alt )
@@ -239,14 +246,28 @@ def _cut_to_fix_len(self, down_str, up_str, interval, anchor):
239246
240247
241248class BaseVCFSeqExtractor (BaseExtractor ):
249+ """
250+ Base class to fetch sequences with variants applied, based
251+ on a given vcf file.
252+ """
253+
242254 def __init__ (self , fasta_file , vcf_file ):
255+ """
256+ Args:
257+ fasta_file: path to the fasta file (can be gzipped)
258+ vcf_file: path to the vcf file (needs to be bgzipped and indexed)
259+ """
243260 self .fasta_file = fasta_file
244261 self .vcf_file = vcf_file
245262 self .variant_extractor = VariantSeqExtractor (fasta_file )
246263 self .vcf = MultiSampleVCF (vcf_file )
247264
248265
249266class SingleVariantVCFSeqExtractor (BaseVCFSeqExtractor ):
267+ """
268+ Fetch a list of sequences, each with one variant applied, based
269+ on the given vcf file.
270+ """
250271
251272 def extract (self , interval , anchor = None , sample_id = None , fixed_len = True ):
252273 for variant in self .vcf .fetch_variants (interval , sample_id ):
@@ -257,6 +278,9 @@ def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
257278
258279
259280class SingleSeqVCFSeqExtractor (BaseVCFSeqExtractor ):
281+ """
282+ Fetch the sequence with all variants applied, based on the given vcf file.
283+ """
260284
261285 def extract (self , interval , anchor = None , sample_id = None , fixed_len = True ):
262286 return self .variant_extractor .extract (
0 commit comments