Skip to content

Commit 5f84634

Browse files
committed
Add script for predicting on bed file input
1 parent 3a46a1b commit 5f84634

1 file changed

Lines changed: 109 additions & 0 deletions

File tree

trainNN/predict_bed.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#!python3
2+
3+
import os
4+
import argparse
5+
import numpy as np
6+
import pandas as pd
7+
8+
import pyfasta
9+
import pyBigWig
10+
import logging
11+
12+
from tensorflow.keras.models import load_model
13+
14+
def dna2onehot(dnaSeq):
15+
DNA2index = {
16+
"A": 0,
17+
"T": 1,
18+
"G": 2,
19+
"C": 3
20+
}
21+
22+
seqLen = len(dnaSeq)
23+
24+
# initialize the matrix to seqlen x 4
25+
seqMatrixs = np.zeros((seqLen,4), dtype=int)
26+
# change the value to matrix
27+
dnaSeq = dnaSeq.upper()
28+
for j in range(0,seqLen):
29+
try:
30+
seqMatrixs[j, DNA2index[dnaSeq[j]]] = 1
31+
except KeyError as e:
32+
continue
33+
return seqMatrixs
34+
35+
def get_data(chunk, genome_pyfasta, bigwigs, nbins):
36+
seqs = []
37+
ms = []
38+
for item in chunk.itertuples():
39+
# get seq info
40+
seq = genome_pyfasta[item.chrom][int(item.start):int(item.end)]
41+
seq_onehot = dna2onehot(seq)
42+
43+
# get chrom info
44+
try:
45+
for idx, bigwig in enumerate(bigwigs):
46+
m = (np.nan_to_num(bigwig.values(item.chrom, item.start, item.end))
47+
.reshape((nbins, -1))
48+
.mean(axis=1, dtype=float))
49+
except RuntimeError as e:
50+
logging.warning(e)
51+
logging.warning(f"Chromatin track doesn't have information in {item} Skip this region...")
52+
raise e
53+
54+
# store
55+
seqs.append(seq_onehot); ms.append(m)
56+
57+
seqs = np.stack(seqs); ms = np.stack(ms)
58+
return {"seq": seqs, "chrom_input": ms}
59+
60+
def predict_generator(bed_file, fasta, bigwig_files, nbins, batchsize=128):
61+
"""
62+
Generator that iterate through the bed file until the end
63+
"""
64+
bed_chunks = pd.read_table(bed_file, header=None, usecols=[0, 1, 2], names=["chrom", "start", "end"], chunksize=batchsize)
65+
genome_pyfasta = pyfasta.Fasta(fasta)
66+
bigwigs = [pyBigWig.open(bw) for bw in bigwig_files]
67+
68+
for chunk in bed_chunks:
69+
try:
70+
input = get_data(chunk, genome_pyfasta, bigwigs, nbins)
71+
except RuntimeError as e:
72+
continue
73+
yield input
74+
75+
def main():
76+
parser = argparse.ArgumentParser(description='Use Bichrom model for prediction given bed file')
77+
parser.add_argument('-mseq', required=True,
78+
help='Sequence Model')
79+
parser.add_argument('-msc', required=True,
80+
help='Bichrom Model')
81+
parser.add_argument('-fa', help='The fasta file for the genome of interest', required=True)
82+
parser.add_argument('-chromtracks', nargs='+', help='A list of BigWig files for all input chromatin experiments, please follow the same order of training data', required=True)
83+
parser.add_argument('-nbins', type=int, help='Number of bins for chromatin tracks', required=True)
84+
parser.add_argument('-prefix', required=True, help='Output prefix')
85+
parser.add_argument('-bed', required=True, help='bed file describing region used for prediction')
86+
args = parser.parse_args()
87+
88+
mseq = load_model(args.mseq)
89+
msc = load_model(args.msc)
90+
pred_dataset = predict_generator(args.bed, args.fa, args.chromtracks, args.nbins)
91+
92+
# get predictions
93+
mseq_probs = []; msc_probs = []
94+
for input in pred_dataset:
95+
mseq_prob = mseq(input, training=False)
96+
msc_prob = msc(input, training=False)
97+
mseq_probs.append(mseq_prob)
98+
msc_probs.append(msc_prob)
99+
mseq_probs = np.concatenate(mseq_probs)
100+
msc_probs = np.concatenate(msc_probs)
101+
102+
# save to file
103+
with open(args.prefix + "mseq_prob.txt", "w") as fmseq, open(args.prefix + "msc_prob.txt", "w") as fmsc:
104+
np.savetxt(fmseq, mseq_probs)
105+
np.savetxt(fmsc, msc_probs)
106+
107+
108+
if __name__ == "__main__":
109+
main()

0 commit comments

Comments
 (0)