|
15 | 15 | from os import makedirs, path |
16 | 16 |
|
17 | 17 | import numpy as np |
| 18 | +import pandas as pd |
18 | 19 | import seqchromloader as scl |
19 | 20 | import torch |
20 | 21 | import utils |
21 | | -from Bio import motifs |
| 22 | +from Bio import motifs as Bio_motifs |
22 | 23 | from captum.concept import Concept |
23 | 24 | from run_tcav_sgd_pca import ( |
24 | 25 | construct_motif_concept_dataloader_from_control, |
@@ -83,6 +84,11 @@ def pair(arg): |
83 | 84 | type=int, |
84 | 85 | help="Number of motifs to insert per bin", |
85 | 86 | ) |
| 87 | + parser.add_argument( |
| 88 | + "--include-reverse-complement", |
| 89 | + action="store_true", |
| 90 | + help="Use both forward and reverse complement version of motifs", |
| 91 | + ) |
86 | 92 | parser.add_argument( |
87 | 93 | "--bin-edges", |
88 | 94 | required=True, |
@@ -165,46 +171,66 @@ def pair(arg): |
165 | 171 | ## custom motifs, use the first control concept as a template |
166 | 172 | ### according to # bins, get the list of buffer |
167 | 173 | breakpoints = [0, *args.bin_edges, args.input_window_length] |
168 | | - bin_list = [(breakpoints[i], breakpoints[i+1]) for i in range(len(breakpoints)-1)] |
| 174 | + bin_list = [ |
| 175 | + (breakpoints[i], breakpoints[i + 1]) for i in range(len(breakpoints) - 1) |
| 176 | + ] |
169 | 177 |
|
170 | 178 | if args.custom_motifs is not None: |
171 | | - with open(args.custom_motifs) as f: |
172 | | - |
173 | | - for m in f: |
174 | | - for buffer_idx, (bin_start, bin_end) in enumerate(bin_list): |
175 | | - start_buffer = bin_start |
176 | | - end_buffer = args.input_window_length - bin_end |
177 | | - motif_name, consensus_seq = m.strip().split("\t") |
178 | | - motif = utils.CustomMotif("motif", consensus_seq) |
179 | | - cn = f"{motif_name}_bin_start_{bin_start}_end_{bin_end}" |
180 | | - seq_dl = construct_motif_concept_dataloader_from_control( |
181 | | - random_regions_df, |
182 | | - args.genome_fasta_file, |
183 | | - motif=motif, |
184 | | - num_motifs=args.num_motifs, |
185 | | - start_buffer=start_buffer, |
186 | | - end_buffer=end_buffer, |
187 | | - batch_size=BATCH_SIZE, |
188 | | - ) |
189 | | - concepts.append( |
190 | | - Concept( |
191 | | - id=idx, |
192 | | - name=cn, |
193 | | - data_iter=zip(seq_dl, control_chrom_dl), |
194 | | - ) |
| 179 | + df = pd.read_table(args.custom_motifs, names=["motif_name", "consensus_seq"]) |
| 180 | + for m in np.unique(df.motif_name): |
| 181 | + motif_name = m |
| 182 | + consensus_seqs = df.loc[ |
| 183 | + df.motif_name == m, "consensus_seq" |
| 184 | + ].tolist() # take all consensus seqs that correspond to the same motif name |
| 185 | + motifs = [] |
| 186 | + for i, c in enumerate(consensus_seqs): |
| 187 | + motif = utils.CustomMotif(f"{m}_{i}", c) |
| 188 | + motifs.append(motif) |
| 189 | + if ( |
| 190 | + args.include_reverse_complement |
| 191 | + ): # add reverse complement if specified |
| 192 | + motif_rc = motif.reverse_complement() |
| 193 | + motifs.append(motif_rc) |
| 194 | + |
| 195 | + for buffer_idx, (bin_start, bin_end) in enumerate(bin_list): |
| 196 | + start_buffer = bin_start |
| 197 | + end_buffer = args.input_window_length - bin_end |
| 198 | + cn = f"{motif_name}_bin_start_{bin_start}_end_{bin_end}" |
| 199 | + seq_dl = construct_motif_concept_dataloader_from_control( |
| 200 | + random_regions_df, |
| 201 | + args.genome_fasta_file, |
| 202 | + motifs=motifs, |
| 203 | + num_motifs=args.num_motifs, |
| 204 | + start_buffer=start_buffer, |
| 205 | + end_buffer=end_buffer, |
| 206 | + batch_size=BATCH_SIZE, |
| 207 | + ) |
| 208 | + concepts.append( |
| 209 | + Concept( |
| 210 | + id=idx, |
| 211 | + name=cn, |
| 212 | + data_iter=zip(seq_dl, control_chrom_dl), |
195 | 213 | ) |
196 | | - idx += 1 |
| 214 | + ) |
| 215 | + idx += 1 |
| 216 | + |
197 | 217 | if args.meme_motifs is not None: |
198 | 218 | with open(args.meme_motifs) as f: |
199 | | - for motif in motifs.parse(f, fmt="MINIMAL"): |
| 219 | + for motif in Bio_motifs.parse(f, fmt="MINIMAL"): |
| 220 | + motifs = [] |
| 221 | + motifs.append(motif) |
| 222 | + if args.include_reverse_complement: |
| 223 | + motif_rc = motif.reverse_complement() |
| 224 | + motifs.append(motif_rc) |
| 225 | + |
200 | 226 | for buffer_idx, (bin_start, bin_end) in enumerate(bin_list): |
201 | 227 | start_buffer = bin_start |
202 | 228 | end_buffer = args.input_window_length - bin_end |
203 | 229 | cn = f"{motif.name.replace('/', '-')}_bin_start_{bin_start}_end_{bin_end}" |
204 | 230 | seq_dl = construct_motif_concept_dataloader_from_control( |
205 | 231 | random_regions_df, |
206 | 232 | args.genome_fasta_file, |
207 | | - motif=motif, |
| 233 | + motifs=motifs, |
208 | 234 | num_motifs=args.num_motifs, |
209 | 235 | start_buffer=start_buffer, |
210 | 236 | end_buffer=end_buffer, |
|
0 commit comments