Skip to content

Commit e31bcab

Browse files
committed
support multiple motifs input in motif_enrichment_knn/dist
1 parent c0585ac commit e31bcab

2 files changed

Lines changed: 41 additions & 19 deletions

File tree

SpatialQuery/spatial_query.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def find_fp_dist(self,
235235

236236
def motif_enrichment_knn(self,
237237
ct: str,
238-
motifs: Union[str, List[str]] = None,
238+
motifs: Union[str, List[str], List[List[str]]] = None,
239239
k: int = 30,
240240
min_support: float = 0.5,
241241
max_dist: float = 20,
@@ -295,14 +295,21 @@ def motif_enrichment_knn(self,
295295
motifs = fp['itemsets']
296296
else:
297297
if isinstance(motifs, str):
298+
motifs = [[motifs]]
299+
elif isinstance(motifs, list) and all(isinstance(m, str) for m in motifs):
298300
motifs = [motifs]
301+
# else: List[List[str]], keep as is
299302

300303
labels_unique = self.labels.unique()
301-
motifs_exc = [m for m in motifs if m not in labels_unique]
302-
if len(motifs_exc) != 0:
303-
print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
304-
motifs = [m for m in motifs if m not in motifs_exc]
305-
motifs = [motifs]
304+
filtered_motifs = []
305+
for motif in motifs:
306+
motif_exc = [m for m in motif if m not in labels_unique]
307+
if len(motif_exc) > 0:
308+
print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
309+
valid_motif = [m for m in motif if m in labels_unique]
310+
if len(valid_motif) > 0:
311+
filtered_motifs.append(valid_motif)
312+
motifs = filtered_motifs
306313

307314
if len(motifs) == 0:
308315
# Return empty DataFrame with same structure
@@ -401,7 +408,7 @@ def motif_enrichment_knn(self,
401408

402409
def motif_enrichment_dist(self,
403410
ct: str,
404-
motifs: Union[str, List[str]] = None,
411+
motifs: Union[str, List[str], List[List[str]]] = None,
405412
max_dist: float = 20,
406413
min_size: int = 0,
407414
min_support: float = 0.5,
@@ -454,14 +461,21 @@ def motif_enrichment_dist(self,
454461
motifs = fp['itemsets']
455462
else:
456463
if isinstance(motifs, str):
464+
motifs = [[motifs]]
465+
elif isinstance(motifs, list) and all(isinstance(m, str) for m in motifs):
457466
motifs = [motifs]
467+
# else: List[List[str]], keep as is
458468

459469
labels_unique = self.labels.unique()
460-
motifs_exc = [m for m in motifs if m not in labels_unique]
461-
if len(motifs_exc) != 0:
462-
print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
463-
motifs = [m for m in motifs if m not in motifs_exc]
464-
motifs = [motifs]
470+
filtered_motifs = []
471+
for motif in motifs:
472+
motif_exc = [m for m in motif if m not in labels_unique]
473+
if len(motif_exc) > 0:
474+
print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
475+
valid_motif = [m for m in motif if m in labels_unique]
476+
if len(valid_motif) > 0:
477+
filtered_motifs.append(valid_motif)
478+
motifs = filtered_motifs
465479

466480
if len(motifs) == 0:
467481
# Return empty DataFrame with same structure

SpatialQuery/spatial_query_multiple_fov.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def find_fp_dist(self,
300300

301301
def motif_enrichment_knn(self,
302302
ct: str,
303-
motifs: Union[str, List[str]] = None,
303+
motifs: Union[str, List[str], List[List[str]]] = None,
304304
dataset: Union[str, List[str]] = None,
305305
k: int = 30,
306306
min_support: float = 0.5,
@@ -374,15 +374,23 @@ def motif_enrichment_knn(self,
374374
motifs = fp['itemsets'].tolist()
375375
else:
376376
if isinstance(motifs, str):
377+
motifs = [[motifs]]
378+
elif isinstance(motifs, list) and all(isinstance(m, str) for m in motifs):
377379
motifs = [motifs]
380+
# else: List[List[str]], keep as is
381+
382+
filtered_motifs = []
383+
for motif in motifs:
384+
motif_exc = [m for m in motif if m not in labels_unique_all]
385+
if len(motif_exc) > 0:
386+
print(f"Found no {motif_exc} in {dataset}. Ignoring them.")
387+
valid_motif = [m for m in motif if m in labels_unique_all]
388+
if len(valid_motif) > 0:
389+
filtered_motifs.append(valid_motif)
378390

379-
motifs_exc = [m for m in motifs if m not in labels_unique_all]
380-
if len(motifs_exc) != 0:
381-
print(f"Found no {motifs_exc} in {dataset}. Ignoring them.")
382-
motifs = [m for m in motifs if m not in motifs_exc]
383-
if len(motifs) == 0:
391+
if len(filtered_motifs) == 0:
384392
raise ValueError(f"All cell types in motifs are missed in {self.label_key}.")
385-
motifs = [motifs]
393+
motifs = filtered_motifs
386394

387395
# Initialize dictionaries to store cell IDs if requested
388396
motif_cell_ids = {}

0 commit comments

Comments
 (0)