Skip to content

Commit 4a10cb7

Browse files
committed
cross_validation_by_subgraph reimplemented for validation by_db and by_entity_type
1 parent fb6ecca commit 4a10cb7

1 file changed

Lines changed: 56 additions & 2 deletions

File tree

src/diffupath/cross_validation.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,62 @@ def cross_validation_by_method(data_input,
5959

6060
}
6161

62-
for method, validation_set in method_validation_inputs.items():
63-
auroc, auprc = get_metrics(*validation_set)
62+
for method, validation_set in method_validation_scores.items():
63+
auroc, auprc = _get_metrics(*validation_set)
64+
auroc_metrics[method].append(auroc)
65+
auprc_metrics[method].append(auprc)
66+
67+
return auroc_metrics, auprc_metrics
68+
69+
70+
def cross_validation_by_subgraph(data_input,
71+
graph,
72+
graph_parameter,
73+
type_list,
74+
universe_kernel=None,
75+
z_normalization=False,
76+
k=100
77+
):
78+
"""Cross validation by subgraph."""
79+
auroc_metrics = defaultdict(list)
80+
auprc_metrics = defaultdict(list)
81+
82+
# Pre-process kernel for subgraphs
83+
kernels = {parameter: regularised_laplacian_kernel(get_subgraph_by_annotation_value(graph,
84+
graph_parameter,
85+
parameter)
86+
)
87+
for parameter in tqdm(type_list, 'Generate kernels from subgraphs')
88+
}
89+
90+
for _ in tqdm(range(k), 'Computate validation scores'):
91+
subgraph_validation_scores = {}
92+
93+
for type, kernel in kernels.items():
94+
95+
if universe_kernel is None:
96+
universe_kernel = kernel
97+
98+
99+
input_diff, validation_diff = _get_random_cv_split_input_and_validation(data_input[type],
100+
kernel)
101+
input_diff_universe, validation_diff_universe = _get_random_cv_split_input_and_validation(data_input[type],
102+
universe_kernel)
103+
104+
scores_on_subgraph = diffuse_raw(graph=None,
105+
scores=input_diff,
106+
k=kernel,
107+
z=z_normalization)
108+
scores_on_universe = diffuse_raw(graph=None,
109+
scores=input_diff_universe,
110+
k=universe_kernel,
111+
z=z_normalization)
112+
113+
subgraph_validation_scores[type + 'on ' + type] = (validation_diff, scores_on_subgraph)
114+
subgraph_validation_scores[type + 'on PathMeUniverse'] = (validation_diff_universe, scores_on_universe)
115+
116+
for method, validation_set in subgraph_validation_scores.items():
117+
auroc, auprc = _get_metrics(*validation_set)
64118
auroc_metrics[method].append(auroc)
65119
auprc_metrics[method].append(auprc)
66120

0 commit comments

Comments
 (0)