@@ -182,9 +182,9 @@ def make_predictions_and_save(avdl, name):
182182def construct_motif_concept_dataloader_from_control (
183183 control_seq_bed_df ,
184184 genome_fasta ,
185- motifs :list ,
185+ motifs : list ,
186186 num_motifs = 128 ,
187- motif_mode = ' pwm' ,
187+ motif_mode = " pwm" ,
188188 start_buffer = 0 ,
189189 end_buffer = 0 ,
190190 batch_size = 8 ,
@@ -343,6 +343,11 @@ def pair(arg):
343343 parser .add_argument (
344344 "--num-motifs" , type = int , default = 12 , help = "Number of motifs to insert"
345345 )
346+ parser .add_argument (
347+ "--include-reverse-complement" ,
348+ action = "store_true" ,
349+ help = "Use both forward and reverse complement version of motifs" ,
350+ )
346351 parser .add_argument (
347352 "--num-samples-per-concept" ,
348353 type = int ,
@@ -422,21 +427,29 @@ def pair(arg):
422427 concepts = []
423428 ## custom motifs, use the first control concept as a template
424429 if args .custom_motifs is not None :
425- df = pd .read_table (args .custom_motifs , names = [' motif_name' , ' consensus_seq' ])
430+ df = pd .read_table (args .custom_motifs , names = [" motif_name" , " consensus_seq" ])
426431 for m in np .unique (df .motif_name ):
427432 motif_name = m
428- consensus_seqs = df .loc [df .motif_name == m , 'consensus_seq' ].tolist () # take all consensus seqs that correspond to the same motif name
433+ consensus_seqs = df .loc [
434+ df .motif_name == m , "consensus_seq"
435+ ].tolist () # take all consensus seqs that correspond to the same motif name
429436 motifs = []
430437 for i , c in enumerate (consensus_seqs ):
431438 motif = utils .CustomMotif (f"{ m } _{ i } " , c )
432439 motifs .append (motif )
440+ if (
441+ args .include_reverse_complement
442+ ): # add reverse complement if specified
443+ motif_rc = motif .reverse_complement ()
444+ motifs .append (motif_rc )
445+
433446 cn = f"{ motif_name } "
434447 seq_dl = construct_motif_concept_dataloader_from_control (
435448 random_regions_df ,
436449 args .genome_fasta_file ,
437450 motifs = motifs ,
438451 num_motifs = args .num_motifs ,
439- motif_mode = ' consensus' ,
452+ motif_mode = " consensus" ,
440453 batch_size = BATCH_SIZE ,
441454 infinite = False ,
442455 )
@@ -451,13 +464,18 @@ def pair(arg):
451464 if args .meme_motifs is not None :
452465 with open (args .meme_motifs ) as f :
453466 for motif in Bio_motifs .parse (f , fmt = "MINIMAL" ):
467+ motifs = []
468+ motifs .append (motif )
469+ if args .include_reverse_complement :
470+ motif_rc = motif .reverse_complement ()
471+ motifs .append (motif_rc )
454472 cn = f"{ motif .name .replace ('/' , '-' )} "
455473 seq_dl = construct_motif_concept_dataloader_from_control (
456474 random_regions_df ,
457475 args .genome_fasta_file ,
458- motifs = [ motif ] ,
476+ motifs = motifs ,
459477 num_motifs = args .num_motifs ,
460- motif_mode = ' pwm' ,
478+ motif_mode = " pwm" ,
461479 batch_size = BATCH_SIZE ,
462480 infinite = False ,
463481 )
@@ -528,7 +546,7 @@ def pair(arg):
528546 logger .info (concepts )
529547
530548 # register hook
531- def get_activation (concept , num_samples = 10 ):
549+ def get_activation (concept , num_samples = args . num_samples_per_concept ):
532550 avs = []
533551 num = 0
534552 for seq , chrom in concept .data_iter :
@@ -596,8 +614,9 @@ def get_activation(concept, num_samples=10):
596614 # set to eval mode for sanity
597615 model .eval ()
598616
599- def get_tpcav_activations (concept ):
617+ def get_tpcav_activations (concept , num_samples = args . num_samples ):
600618 avs_pca = []
619+ num = 0
601620 for seq , chrom in concept .data_iter :
602621 seq = utils .seq_transform_fn (seq )
603622 chrom = utils .chrom_transform_fn (chrom )
@@ -613,9 +632,13 @@ def get_tpcav_activations(concept):
613632 else :
614633 av_pca = av_residual
615634 avs_pca .append (av_pca .detach ().cpu ())
635+
636+ num += av_pca .shape [0 ]
637+ if num >= num_samples :
638+ break
616639 with torch .no_grad ():
617640 del seq , av , av_projected , av_residual
618- return torch .cat (avs_pca ).detach ().cpu ()
641+ return torch .cat (avs_pca ).detach ().cpu ()[: num_samples ]
619642
620643 # get activations of each concept and train classifier for each pair
621644 pool = Pool ()
0 commit comments