Skip to content

Commit 2371f6e

Browse files
committed
print_dict_dimensions and general refactor in DiffuPath utils
1 parent b5bbc88 commit 2371f6e

1 file changed

Lines changed: 46 additions & 54 deletions

File tree

src/diffupath/utils.py

Lines changed: 46 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import pickle
88
import random
9-
109
from statistics import mean
1110

1211
import numpy as np
@@ -29,23 +28,31 @@ def to_pickle(to_pickle, output):
2928
pickle.dump(to_pickle, file)
3029

3130

32-
def print_dict_dimensions(entities_db, title):
31+
def print_dict_dimensions(entities_db, title='', message='Total number of '):
3332
"""Print dimension of the dictionary"""
34-
total = 0
33+
total = set()
34+
m = f'{title}\n'
3535

3636
for k1, v1 in entities_db.items():
3737
m = ''
3838
if isinstance(v1, dict):
3939
for k2, v2 in v1.items():
4040
m += f'{k2}({len(v2)}), '
41-
total += len(v2)
41+
total.update(v2)
4242
else:
4343
m += f'{len(v1)} '
44-
total += len(v1)
44+
total.update(v1)
45+
46+
print(f'{message} {k1}: {m} ')
4547

46-
print(f'Total number of {k1}: {m} ')
48+
print(f'Total: {len(total)} ')
4749

48-
print(f'Total: {total} ')
50+
51+
def print_dict(dict_to_print, message=''):
52+
"""Print dimension of the dictionary"""
53+
54+
for k1, v1 in dict_to_print.items():
55+
print(f'{message} {k1}: {v1} ')
4956

5057

5158
def get_labels_set_from_dict(entities):
@@ -58,49 +65,34 @@ def get_labels_set_from_dict(entities):
5865

5966
def reduce_dict_dimension(dict):
6067
"""Reduce dictionary dimension."""
61-
return {
62-
k: set(itertools.chain.from_iterable(entities.values()))
63-
for k, entities in dict.items()
64-
}
65-
66-
67-
def check_substrings(dataset_nodes, db_nodes):
68-
mapping_substrings = set()
69-
70-
for entity in dataset_nodes:
71-
if isinstance(entity, tuple):
72-
for subentity in entity:
73-
for entity_db in db_nodes:
74-
if isinstance(entity_db, tuple):
75-
for subentity_db in entity_db:
76-
if subentity_db in subentity or subentity in subentity_db:
77-
mapping_substrings.add(entity_db)
78-
break
79-
break
80-
else:
81-
if entity_db in subentity or subentity in entity_db:
82-
mapping_substrings.add(entity_db)
83-
break
84-
else:
85-
for entity_db in db_nodes:
86-
if isinstance(entity_db, tuple):
87-
for subentity_db in entity_db:
88-
if subentity_db in entity or entity in subentity_db:
89-
mapping_substrings.add(entity_db)
90-
break
91-
break
92-
else:
93-
if entity_db in entity or entity in entity_db:
94-
mapping_substrings.add(entity_db)
95-
break
96-
97-
return mapping_substrings
68+
reduced_dict = {}
69+
70+
for k1, entities1 in dict.items():
71+
for k2, entities2 in entities1.items():
72+
if k1 in reduced_dict.keys():
73+
reduced_dict[k1].update(entities2)
74+
else:
75+
reduced_dict[k1] = entities2
76+
77+
return reduced_dict
78+
9879

9980
def split_random_two_subsets(to_split):
100-
half_1 = random.sample(population=list(to_split), k=int(len(to_split) / 2))
101-
half_2 = list(set(to_split) - set(half_1))
81+
"""Split random two subsets."""
82+
if isinstance(to_split, dict):
83+
to_split_labels = list(to_split.keys())
84+
else:
85+
to_split_labels = to_split
86+
87+
half_1 = random.sample(population=list(to_split_labels), k=int(len(to_split_labels) / 2))
88+
half_2 = list(set(to_split_labels) - set(half_1))
89+
90+
if isinstance(to_split, dict):
91+
return {entity_label: to_split[entity_label] for entity_label in half_1}, \
92+
{entity_label: to_split[entity_label] for entity_label in half_2}
93+
else:
94+
return half_1, half_2
10295

103-
return half_1, half_2
10496

10597
def hide_true_positives(to_split, k=0.5):
10698
"""Hide relative number of labels."""
@@ -195,18 +187,18 @@ def get_count_and_labels_from_two_dim_dict(mapping_by_database_and_entity):
195187

196188
# entity_type_map = {'metabolite_nodes': 'metabolite', 'mirna_nodes': 'micrornas', 'gene_nodes': 'genes', 'bp_nodes': 'bps'}
197189

198-
for db_name, entities_by_type in mapping_by_database_and_entity.items():
190+
for type_label, entities in mapping_by_database_and_entity.items():
199191
db_count = []
200192
db_percentage = []
201193

202-
db_labels.append(db_name)
194+
db_labels.append(type_label)
203195

204-
if not types_labels:
205-
types_labels = entities_by_type[0].keys()
196+
if types_labels == []:
197+
types_labels = list(entities.keys())
206198

207-
for entity_type, entities_tupple in entities_by_type[0].items():
208-
db_count.append(len(entities_tupple[0]))
209-
db_percentage.append(entities_tupple[1])
199+
for entity_type, entities_tupple in entities.items():
200+
db_count.append(entities_tupple[1])
201+
db_percentage.append(entities_tupple[0])
210202

211203
all_count.append(db_count)
212204
all_percentage.append(db_percentage)

0 commit comments

Comments
 (0)