@@ -235,7 +235,7 @@ def find_fp_dist(self,
235235
236236 def motif_enrichment_knn (self ,
237237 ct : str ,
238- motifs : Union [str , List [str ]] = None ,
238+ motifs : Union [str , List [str ], List [ List [ str ]] ] = None ,
239239 k : int = 30 ,
240240 min_support : float = 0.5 ,
241241 max_dist : float = 20 ,
@@ -295,14 +295,21 @@ def motif_enrichment_knn(self,
295295 motifs = fp ['itemsets' ]
296296 else :
297297 if isinstance (motifs , str ):
298+ motifs = [[motifs ]]
299+ elif isinstance (motifs , list ) and all (isinstance (m , str ) for m in motifs ):
298300 motifs = [motifs ]
301+ # else: List[List[str]], keep as is
299302
300303 labels_unique = self .labels .unique ()
301- motifs_exc = [m for m in motifs if m not in labels_unique ]
302- if len (motifs_exc ) != 0 :
303- print (f"Found no { motifs_exc } in { self .label_key } . Ignoring them." )
304- motifs = [m for m in motifs if m not in motifs_exc ]
305- motifs = [motifs ]
304+ filtered_motifs = []
305+ for motif in motifs :
306+ motif_exc = [m for m in motif if m not in labels_unique ]
307+ if len (motif_exc ) > 0 :
308+ print (f"Found no { motif_exc } in { self .label_key } . Ignoring them." )
309+ valid_motif = [m for m in motif if m in labels_unique ]
310+ if len (valid_motif ) > 0 :
311+ filtered_motifs .append (valid_motif )
312+ motifs = filtered_motifs
306313
307314 if len (motifs ) == 0 :
308315 # Return empty DataFrame with same structure
@@ -401,7 +408,7 @@ def motif_enrichment_knn(self,
401408
402409 def motif_enrichment_dist (self ,
403410 ct : str ,
404- motifs : Union [str , List [str ]] = None ,
411+ motifs : Union [str , List [str ], List [ List [ str ]] ] = None ,
405412 max_dist : float = 20 ,
406413 min_size : int = 0 ,
407414 min_support : float = 0.5 ,
@@ -454,14 +461,21 @@ def motif_enrichment_dist(self,
454461 motifs = fp ['itemsets' ]
455462 else :
456463 if isinstance (motifs , str ):
464+ motifs = [[motifs ]]
465+ elif isinstance (motifs , list ) and all (isinstance (m , str ) for m in motifs ):
457466 motifs = [motifs ]
467+ # else: List[List[str]], keep as is
458468
459469 labels_unique = self .labels .unique ()
460- motifs_exc = [m for m in motifs if m not in labels_unique ]
461- if len (motifs_exc ) != 0 :
462- print (f"Found no { motifs_exc } in { self .label_key } . Ignoring them." )
463- motifs = [m for m in motifs if m not in motifs_exc ]
464- motifs = [motifs ]
470+ filtered_motifs = []
471+ for motif in motifs :
472+ motif_exc = [m for m in motif if m not in labels_unique ]
473+ if len (motif_exc ) > 0 :
474+ print (f"Found no { motif_exc } in { self .label_key } . Ignoring them." )
475+ valid_motif = [m for m in motif if m in labels_unique ]
476+ if len (valid_motif ) > 0 :
477+ filtered_motifs .append (valid_motif )
478+ motifs = filtered_motifs
465479
466480 if len (motifs ) == 0 :
467481 # Return empty DataFrame with same structure
0 commit comments