Skip to content

Commit b582606

Browse files
committed
Add option to include the reverse-complement (rc) version of each motif
1 parent 83baaf4 commit b582606

2 files changed

Lines changed: 47 additions & 16 deletions

File tree

scripts/run_tcav_sgd_pca.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ def make_predictions_and_save(avdl, name):
182182
def construct_motif_concept_dataloader_from_control(
183183
control_seq_bed_df,
184184
genome_fasta,
185-
motifs:list,
185+
motifs: list,
186186
num_motifs=128,
187-
motif_mode='pwm',
187+
motif_mode="pwm",
188188
start_buffer=0,
189189
end_buffer=0,
190190
batch_size=8,
@@ -343,6 +343,11 @@ def pair(arg):
343343
parser.add_argument(
344344
"--num-motifs", type=int, default=12, help="Number of motifs to insert"
345345
)
346+
parser.add_argument(
347+
"--include-reverse-complement",
348+
action="store_true",
349+
help="Use both forward and reverse complement version of motifs",
350+
)
346351
parser.add_argument(
347352
"--num-samples-per-concept",
348353
type=int,
@@ -422,21 +427,29 @@ def pair(arg):
422427
concepts = []
423428
## custom motifs, use the first control concept as a template
424429
if args.custom_motifs is not None:
425-
df = pd.read_table(args.custom_motifs, names=['motif_name', 'consensus_seq'])
430+
df = pd.read_table(args.custom_motifs, names=["motif_name", "consensus_seq"])
426431
for m in np.unique(df.motif_name):
427432
motif_name = m
428-
consensus_seqs = df.loc[df.motif_name==m, 'consensus_seq'].tolist() # take all consensus seqs that correspond to the same motif name
433+
consensus_seqs = df.loc[
434+
df.motif_name == m, "consensus_seq"
435+
].tolist() # take all consensus seqs that correspond to the same motif name
429436
motifs = []
430437
for i, c in enumerate(consensus_seqs):
431438
motif = utils.CustomMotif(f"{m}_{i}", c)
432439
motifs.append(motif)
440+
if (
441+
args.include_reverse_complement
442+
): # add reverse complement if specified
443+
motif_rc = motif.reverse_complement()
444+
motifs.append(motif_rc)
445+
433446
cn = f"{motif_name}"
434447
seq_dl = construct_motif_concept_dataloader_from_control(
435448
random_regions_df,
436449
args.genome_fasta_file,
437450
motifs=motifs,
438451
num_motifs=args.num_motifs,
439-
motif_mode='consensus',
452+
motif_mode="consensus",
440453
batch_size=BATCH_SIZE,
441454
infinite=False,
442455
)
@@ -451,13 +464,18 @@ def pair(arg):
451464
if args.meme_motifs is not None:
452465
with open(args.meme_motifs) as f:
453466
for motif in Bio_motifs.parse(f, fmt="MINIMAL"):
467+
motifs = []
468+
motifs.append(motif)
469+
if args.include_reverse_complement:
470+
motif_rc = motif.reverse_complement()
471+
motifs.append(motif_rc)
454472
cn = f"{motif.name.replace('/', '-')}"
455473
seq_dl = construct_motif_concept_dataloader_from_control(
456474
random_regions_df,
457475
args.genome_fasta_file,
458-
motifs=[motif],
476+
motifs=motifs,
459477
num_motifs=args.num_motifs,
460-
motif_mode='pwm',
478+
motif_mode="pwm",
461479
batch_size=BATCH_SIZE,
462480
infinite=False,
463481
)
@@ -528,7 +546,7 @@ def pair(arg):
528546
logger.info(concepts)
529547

530548
# register hook
531-
def get_activation(concept, num_samples=10):
549+
def get_activation(concept, num_samples=args.num_samples_per_concept):
532550
avs = []
533551
num = 0
534552
for seq, chrom in concept.data_iter:
@@ -596,8 +614,9 @@ def get_activation(concept, num_samples=10):
596614
# set to eval mode for sanity
597615
model.eval()
598616

599-
def get_tpcav_activations(concept):
617+
def get_tpcav_activations(concept, num_samples=args.num_samples):
600618
avs_pca = []
619+
num = 0
601620
for seq, chrom in concept.data_iter:
602621
seq = utils.seq_transform_fn(seq)
603622
chrom = utils.chrom_transform_fn(chrom)
@@ -613,9 +632,13 @@ def get_tpcav_activations(concept):
613632
else:
614633
av_pca = av_residual
615634
avs_pca.append(av_pca.detach().cpu())
635+
636+
num += av_pca.shape[0]
637+
if num >= num_samples:
638+
break
616639
with torch.no_grad():
617640
del seq, av, av_projected, av_residual
618-
return torch.cat(avs_pca).detach().cpu()
641+
return torch.cat(avs_pca).detach().cpu()[:num_samples]
619642

620643
# get activations of each concept and train classifier for each pair
621644
pool = Pool()

scripts/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
from itertools import cycle
88

9+
import Bio
910
import numpy as np
1011
import pandas as pd
1112
import pyfaidx
1213
import seqchromloader as scl
1314
import torch
1415
from Bio import SeqIO
1516
from deeplift.dinuc_shuffle import dinuc_shuffle
16-
from models import ConvTowerDomain_v6
1717
from pybedtools import BedTool
1818
from pyfaidx import Fasta
1919
from seq_utils import insert_motif_into_seq, insert_region_into_seq
@@ -289,7 +289,10 @@ def center_windows(df, window_len=1024):
289289
df = df.assign(mid=lambda x: ((x["start"] + x["end"]) / 2).astype(int)).assign(
290290
start=lambda x: x["mid"] - halfR, end=lambda x: x["mid"] + halfR
291291
)
292-
return df[["chrom", "start", "end"]]
292+
if "strand" in df.columns:
293+
return df[["chrom", "start", "end", "strand"]]
294+
else:
295+
return df[["chrom", "start", "end"]]
293296

294297

295298
def collate_seq(batch):
@@ -363,7 +366,8 @@ def seq_dataloader_from_dataframe(
363366
):
364367
seq_df = center_windows(seq_df, window_len=window_len)
365368
seq_df["label"] = -1
366-
seq_df["strand"] = "+"
369+
if not "strand" in seq_df.columns:
370+
seq_df["strand"] = "+"
367371
# print(f"Filtering out concept samples that don't exist in the genome...")
368372
seq_df = scl.filter_chromosomes(seq_df, to_keep=Fasta(genome_fasta).keys())
369373
dl = scl.SeqChromDatasetByDataFrame(
@@ -421,7 +425,8 @@ def chrom_dataloader_from_dataframe(
421425
):
422426
chrom_df = center_windows(chrom_df, window_len=input_window_length)
423427
chrom_df["label"] = -1
424-
chrom_df["strand"] = "+"
428+
if not "strand" in chrom_df.columns:
429+
chrom_df["strand"] = "+"
425430
# print(f"Filtering out concept samples that don't exist in the genome...")
426431
chrom_df = scl.filter_chromosomes(chrom_df, to_keep=Fasta(genome_fasta).keys())
427432
dl = scl.SeqChromDatasetByDataFrame(
@@ -475,7 +480,8 @@ def seq_dataloader(self):
475480
)
476481
seq_df = center_windows(seq_df, window_len=self.window_len)
477482
seq_df["label"] = -1
478-
seq_df["strand"] = "+"
483+
if not "strand" in seq_df.columns:
484+
seq_df["strand"] = "+"
479485
# print(f"Filtering out concept samples that don't exist in the genome...")
480486
seq_df = scl.filter_chromosomes(
481487
seq_df, to_keep=Fasta(self.genome_fasta).keys()
@@ -515,7 +521,8 @@ def chrom_dataloader(self):
515521
)
516522
chrom_df = center_windows(chrom_df, window_len=self.window_len)
517523
chrom_df["label"] = -1
518-
chrom_df["strand"] = "+"
524+
if not "strand" in chrom_df.columns:
525+
chrom_df["strand"] = "+"
519526
# print(f"Filtering out concept samples that don't exist in the genome...")
520527
chrom_df = scl.filter_chromosomes(
521528
chrom_df, to_keep=Fasta(self.genome_fasta).keys()
@@ -610,4 +617,5 @@ def __len__(self):
610617

611618
def reverse_complement(self):
    """Return a reverse-complemented copy of this motif.

    The returned object's ``consensus`` is the reverse complement of this
    motif's consensus (computed by ``Bio.Seq.reverse_complement``) and its
    ``name`` is this motif's name with ``"_rc"`` appended. ``self`` is not
    modified, so the forward motif and the returned motif can be kept side
    by side in the same list.

    Returns:
        A shallow copy of ``self`` carrying the reverse-complemented
        consensus and the ``"_rc"``-suffixed name.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    # Mutating and returning ``self`` here means a caller that appends the
    # motif and then appends the return value ends up with the same,
    # already reversed object twice and no forward-strand motif at all.
    # NOTE(review): this is a shallow copy; any other attribute derived from
    # ``consensus`` is carried over unchanged -- confirm none exists.
    rc_motif = copy.copy(self)
    rc_motif.consensus = Bio.Seq.reverse_complement(self.consensus)
    rc_motif.name = self.name + "_rc"
    return rc_motif

0 commit comments

Comments
 (0)