Skip to content

Commit b64b482

Browse files
committed
DiffuPath cli refactors, adapted to the recoding in cross validation
1 parent 4a10cb7 commit b64b482

2 files changed

Lines changed: 139 additions & 122 deletions

File tree

src/diffupath/cli.py

Lines changed: 138 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,26 @@
55
import json
66
import logging
77
import sys
8+
from collections import defaultdict
89

910
import click
11+
import networkx as nx
1012
from bio2bel.constants import get_global_connection
13+
from diffupath.utils import reduce_dict_dimension
14+
from diffupy.constants import EMOJI, RAW, CSV, JSON
1115
from diffupy.diffuse import diffuse as run_diffusion
12-
from diffupy.process_input import process_input
13-
from diffupy.utils import process_network_from_cli
16+
from diffupy.process_input import process_map_and_format_input_data_for_diff
17+
from diffupy.process_network import get_kernel_from_network_path, process_kernel_from_file, process_graph_from_file
18+
from diffupy.utils import from_json, to_json
1419

1520
from .constants import *
1621
from .cross_validation import cross_validation_by_method
17-
from .input_mapping import process_input_from_cli
18-
from .validation_datasets_parsers import parse_set1, parse_set2, parse_set3
1922

2023
logger = logging.getLogger(__name__)
2124

22-
#: Parsing methods for each dataset
23-
PARSING_METHODS = {
24-
'1': parse_set1,
25-
'2': parse_set2,
26-
'3': parse_set3,
27-
}
25+
GRAPH_PATH = os.path.join(DEFAULT_DIFFUPATH_DIR, 'pickles', 'universe',
26+
'pathme_universe_non_flatten_collapsed_names_no_isolates_16_03_2020.pickle')
27+
KERNEL_PATH = os.path.join(DEFAULT_DIFFUPATH_DIR, 'kernels', 'kernel_regularized_pathme_universe.pickle')
2828

2929

3030
@click.group(help='DiffuPy')
@@ -40,16 +40,16 @@ def diffusion():
4040

4141
@diffusion.command()
4242
@click.option(
43-
'-n', '--network',
44-
help='Path to the network graph or kernel',
43+
'-i', '--input',
44+
help='Input data',
4545
required=True,
46-
type=click.Path(exists=True, dir_okay=False)
46+
type=click.Path(exists=True, dir_okay=True)
4747
)
4848
@click.option(
49-
'-i', '--data',
50-
help='Input data',
51-
required=True,
52-
type=click.Path(exists=True, dir_okay=False)
49+
'-n', '--network',
50+
help='Path to the network graph or kernel',
51+
default=KERNEL_PATH,
52+
type=click.Path(exists=True, dir_okay=True)
5353
)
5454
@click.option(
5555
'-o', '--output',
@@ -61,28 +61,29 @@ def diffusion():
6161
'-m', '--method',
6262
help='Diffusion method',
6363
type=click.Choice(METHODS),
64-
required=True,
64+
default=RAW,
6565
)
6666
@click.option(
6767
'-b', '--binarize',
6868
help='If logFC provided in dataset, convert logFC to binary (e.g., up-regulated entities to 1, down-regulated to '
6969
'-1). For scoring methods that accept quantitative values (i.e., raw & z), node labels can also be codified '
7070
'with LogFC (in this case, set binarize==False).',
7171
type=bool,
72-
default=True,
72+
default=False,
7373
show_default=True,
7474
)
7575
@click.option(
7676
'-t', '--threshold',
7777
help='Codify node labels by applying a threshold to logFC in input.',
78+
default=None,
7879
type=float,
7980
)
8081
@click.option(
8182
'-a', '--absolute_value',
82-
help='Codify node labels by applying threshold to |logFC| in input. If absolute_value is set to False, node labels '
83-
'will be signed.',
83+
help='Codify node labels by applying threshold to | logFC | in input. If absolute_value is set to False,'
84+
'node labels will be signed.',
8485
type=bool,
85-
default=True,
86+
default=False,
8687
show_default=True,
8788
)
8889
@click.option(
@@ -92,72 +93,83 @@ def diffusion():
9293
default=0.05,
9394
show_default=True,
9495
)
95-
def diffuse(
96-
network: str,
97-
data: str,
98-
output: str,
99-
method: str,
100-
binarize: bool,
101-
absolute_value: bool,
102-
threshold: float,
103-
p_value: float,
96+
@click.option(
97+
'-f', '--output_format',
98+
help='Statistical significance (p-value).',
99+
type=float,
100+
default=CSV,
101+
show_default=True,
102+
)
103+
def run(
104+
input: str,
105+
network: str = KERNEL_PATH,
106+
output: str = OUTPUT_DIR,
107+
method: str = RAW,
108+
binarize: bool = False,
109+
threshold: float = None,
110+
absolute_value: bool = False,
111+
p_value: float = 0.05,
112+
output_format: str = CSV
104113
):
105114
"""Run a diffusion method over a network or pre-generated kernel."""
106115
click.secho(f'{EMOJI} Loading graph from {network} {EMOJI}')
107-
graph = process_network_from_cli(network)
108116

109-
click.secho(
110-
f'{EMOJI} Graph loaded with: \n'
111-
f'{graph.number_of_nodes()} nodes\n'
112-
f'{graph.number_of_edges()} edges\n'
113-
f'{EMOJI}'
114-
)
117+
kernel = get_kernel_from_network_path(network)
115118

116-
click.secho(f'Codifying data from {data}.')
119+
click.secho(f'Processing data input from {input}.')
117120

118-
input_scores_dict = process_input(data, method, binarize, absolute_value, p_value, threshold)
121+
input_scores_dict = process_map_and_format_input_data_for_diff(input,
122+
kernel,
123+
method,
124+
binarize,
125+
absolute_value,
126+
p_value,
127+
threshold,
128+
)
119129

120-
click.secho(f'Running the diffusion algorithm.')
130+
click.secho(f'Computing the diffusion algorithm.')
121131

122132
results = run_diffusion(
123133
input_scores_dict,
124134
method,
125-
graph,
135+
k=kernel
126136
)
127137

128-
json.dump(results, output, indent=2)
138+
if output_format is CSV:
139+
results.to_csv(output)
140+
141+
elif output_format is JSON:
142+
json.dump(results, output, indent=2)
129143

130-
click.secho(f'Finished!')
144+
click.secho(f'{EMOJI} Diffusion performed with success. Output located at {output} {EMOJI}')
131145

132146

133147
@diffusion.command()
134148
@click.option(
135-
'-d', '--data',
136-
help='Input data',
137-
required=True,
138-
type=click.Path(exists=True, dir_okay=False),
149+
'-c', '--comparison',
150+
help='Comparison method',
151+
default=BY_METHOD,
152+
show_default=True,
153+
type=click.Choice(EVALUATION_COMPARISONS),
139154
)
140155
@click.option(
141-
'-n', '--network',
142-
help='Path to the network graph or kernel',
143-
required=True,
156+
'-i', '--input_path',
157+
default=os.path.join(ROOT_RESULTS_DIR, 'data', 'input_mappings'),
158+
show_default=True,
159+
type=click.Path(exists=True, dir_okay=True),
160+
)
161+
@click.option(
162+
'-k', '--kernel',
163+
help='Path to the kernel',
164+
default=GRAPH_PATH,
144165
type=click.Path(exists=True, dir_okay=False)
145166
)
146167
@click.option(
147-
'-g', '--graph_path',
168+
'-g', '--graph',
148169
help='Path to the network as a graph',
170+
default=KERNEL_PATH,
149171
type=click.Path(exists=True, dir_okay=False),
150172
)
151-
@click.option(
152-
'-q', '--quantitative', # TODO Automatize if possible, check type of label_input.
153-
help='Generate categorical label_input from labels',
154-
is_flag=False,
155-
)
156-
@click.option(
157-
'-n', '--network_as_graph',
158-
help='If given expects graph else expects as a kernel',
159-
is_flag=False,
160-
)
161173
@click.option(
162174
'-o', '--output',
163175
help='Output path for the results',
@@ -172,78 +184,83 @@ def diffuse(
172184
show_default=True,
173185
type=int,
174186
)
175-
@click.option(
176-
'-c', '--comparison',
177-
help='Comparison method',
178-
default='by_method',
179-
show_default=True,
180-
type=click.Choice(EVALUATION_METHODS),
181-
)
182-
@click.option(
183-
'-k', '--dataset',
184-
help='Key for the datasets presented in the paper',
185-
show_default=True,
186-
default=1,
187-
type=click.Choice(DATASETS),
188-
)
189187
def evaluate(
190-
data: str,
191-
network: str,
192-
graph_path: str,
193-
quantitative: bool, # TODO Automatize if possible, check type of label_input.
194-
network_as_graph: bool, # TODO Automatize if possible, check type of graph.
195-
output: str,
196-
iterations: int,
197-
comparison: str,
198-
dataset: int,
188+
comparison: str = BY_METHOD,
189+
input_path: str = os.path.join(ROOT_RESULTS_DIR, 'data', 'input_mappings'),
190+
graph: str = GRAPH_PATH,
191+
kernel: str = KERNEL_PATH,
192+
output: str = OUTPUT_DIR,
193+
iterations: int = 100,
199194
):
200195
"""Evaluate a kernel/network on one of the three presented datasets."""
201-
click.secho(f'{EMOJI} Loading label_input for cross-validation... {EMOJI}')
202-
203-
if not network_as_graph and not graph_path:
204-
raise ValueError("Network not provided in graph format, which is required for evaluation.")
205-
206-
_, kernel, labels_mapping, graph = process_input_from_cli(
207-
PARSING_METHODS[dataset],
208-
network,
209-
data,
210-
network_as_graph,
211-
quantitative,
212-
)
196+
click.secho(f'{EMOJI} Loading network for random cross-validation... {EMOJI}')
197+
graph = process_graph_from_file(graph)
198+
kernel = process_kernel_from_file(kernel)
199+
200+
nx.number_of_isolates(graph)
201+
graph.remove_nodes_from({
202+
node
203+
for node in nx.isolates(graph)
204+
})
205+
206+
graph.summarize()
207+
208+
click.secho(f'{EMOJI} Loading data for cross-validation... {EMOJI}')
209+
MAPPING_PATH_DATASET_1 = os.path.join(input_path, 'dataset_1_mapping.json')
210+
dataset1_mapping_by_database_and_entity = from_json(MAPPING_PATH_DATASET_1)
211+
dataset1_mapping_by_database = reduce_dict_dimension(dataset1_mapping_by_database_and_entity)
212+
dataset1_mapping_all_labels = {entity: entity_value
213+
for entity_type, entity_set in dataset1_mapping_by_database.items()
214+
for entity, entity_value in entity_set.items()
215+
}
216+
217+
MAPPING_PATH_DATASET_2 = os.path.join(input_path, 'dataset_2_mapping.json')
218+
dataset2_mapping_by_database_and_entity = from_json(MAPPING_PATH_DATASET_2)
219+
dataset2_mapping_by_database = reduce_dict_dimension(dataset2_mapping_by_database_and_entity)
220+
dataset2_mapping_all_labels = {entity: entity_value
221+
for entity_type, entity_set in dataset2_mapping_by_database.items()
222+
for entity, entity_value in entity_set.items()
223+
}
224+
225+
MAPPING_PATH_DATASET_3 = os.path.join(input_path, 'dataset_3_mapping.json')
226+
dataset3_mapping_by_database_and_entity = from_json(MAPPING_PATH_DATASET_3)
227+
dataset3_mapping_by_database = reduce_dict_dimension(dataset3_mapping_by_database_and_entity)
228+
dataset3_mapping_all_labels = {entity: entity_value
229+
for entity_type, entity_set in dataset3_mapping_by_database.items()
230+
for entity, entity_value in entity_set.items()
231+
}
232+
233+
if comparison == BY_METHOD:
234+
click.secho(f'{EMOJI} Evaluating by method... {EMOJI}')
213235

214-
if not network_as_graph:
215-
graph = process_network_from_cli(graph_path)
236+
metrics_by_method = defaultdict(lambda: defaultdict(lambda: list))
216237

217-
if comparison == 'by_method':
218-
click.secho(f'{EMOJI} Evaluating by method... {EMOJI}')
238+
click.secho(f'{EMOJI} Running cross_validation_by_method for Dataset 1... {EMOJI}')
239+
metrics_by_method['auroc']['Dataset 1'], metrics_by_method['auprc']['Dataset 1'] = cross_validation_by_method(
240+
dataset1_mapping_all_labels,
241+
graph,
242+
kernel,
243+
k=iterations)
219244

220-
auroc_metrics, auprc_metrics = cross_validation_by_method(
221-
labels_mapping,
245+
click.secho(f'{EMOJI} Running cross_validation_by_method for Dataset 2... {EMOJI}')
246+
metrics_by_method['auroc']['Dataset 2'], metrics_by_method['auprc']['Dataset 2'] = cross_validation_by_method(
247+
dataset2_mapping_all_labels,
222248
graph,
223249
kernel,
224-
k=iterations,
225-
)
226-
elif comparison == 'by_db':
227-
click.secho(f'{EMOJI} Evaluating by database... {EMOJI}')
228-
229-
# TODO to adapt from 'get_one_x_in_cv_inputs_from_subsets', and label_input treatment subset division.
230-
auroc_metrics, auprc_metrics = cross_validation_by_method(
231-
labels_mapping,
250+
k=iterations)
251+
252+
click.secho(f'{EMOJI} Running cross_validation_by_method for Dataset 3... {EMOJI}')
253+
metrics_by_method['auroc']['Dataset 3'], metrics_by_method['auprc']['Dataset 3'] = cross_validation_by_method(
254+
dataset3_mapping_all_labels,
232255
graph,
233256
kernel,
234-
k=iterations,
235-
)
257+
k=iterations)
258+
259+
236260
else:
237261
raise ValueError("The comparison method provided not match any provided method.")
238262

239-
with open(os.path.join(output, 'metrics.json'), 'w') as outfile:
240-
json.dump(
241-
{'auroc_metrics': auroc_metrics,
242-
'auprc_metrics': auprc_metrics
243-
},
244-
outfile,
245-
indent=2,
246-
)
263+
to_json(metrics_by_method, output)
247264

248265
click.secho(f'{EMOJI} Random cross-validation performed with success. Output located at {output}... {EMOJI}')
249266

src/diffupath/topological_analyses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import networkx as nx
1010
import numpy as np
1111
from diffupy.matrix import LaplacianMatrix, Matrix
12-
from diffupy.utils import get_simple_graph_from_multigraph
12+
from diffupy.process_network import get_simple_graph_from_multigraph
1313

1414

1515
def generate_pagerank_baseline(graph: nx.Graph, background_mat: Matrix) -> Matrix:

0 commit comments

Comments
 (0)