|
| 1 | +# pylint: disable=missing-docstring, too-complex, invalid-name, too-many-branches, too-many-locals |
| 2 | + |
1 | 3 | """ |
2 | 4 | A module that evaluates a model and returns the prediction |
3 | 5 | """ |
4 | 6 |
|
5 | | -#from cmstoolbox import sitereadiness |
| 7 | +import os |
| 8 | +import random |
| 9 | +import itertools |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | +import keras as K |
| 14 | + |
| 15 | + |
def modified_site_name(site):
    """
    Strip the final '_'-separated component from a site name.

    e.g. 'T2_US_MIT_Disk' -> 'T2_US_MIT'.  A name with no underscore
    collapses to the empty string.

    :param str site: the full site name
    :returns: the site name with its last token removed
    :rtype: str
    """
    # str.join replaces the original quadratic concatenate-in-a-loop build.
    # The trailing rstrip('_') preserves the original behaviour exactly:
    # consecutive underscores before the dropped token are also trimmed.
    return '_'.join(site.split('_')[:-1]).rstrip('_')
| 23 | + |
def build_table(df, template_table):
    """
    Fill a copy of *template_table* (an exit-code x site-name count matrix)
    with the error counts found in *df*.

    :param dict df: mapping of exit code (str or int) -> {site name: count}
    :param pandas.DataFrame template_table: zeroed matrix indexed by exit
        code with one column per site name known at training time
    :returns: the populated copy of the template
    :rtype: pandas.DataFrame

    Sites that were not present in the training data are remapped: first by
    stripping the last '_'-separated token of the name; failing that, by
    substituting a RANDOMLY chosen known site of the same tier (T0-T3).
    NOTE(review): the random proxy makes predictions non-deterministic for
    unknown sites -- confirm this is intended.
    """
    sparse_df = template_table.copy()

    # Bucket the known site names by tier, read from the second character
    # of the name (e.g. 'T2_...' -> tier '2').  'NA' is the placeholder
    # column and belongs to no tier.
    tier0_sites, tier1_sites, tier2_sites, tier3_sites = [], [], [], []
    for i in sparse_df.keys():
        if i != 'NA':
            if i[1] == '0':
                tier0_sites.append(i)
            elif i[1] == '1':
                tier1_sites.append(i)
            elif i[1] == '2':
                tier2_sites.append(i)
            elif i[1] == '3':
                tier3_sites.append(i)

    n0, n1, n2, n3 = len(tier0_sites), len(tier1_sites), len(tier2_sites), len(tier3_sites)
    for exit_code, site_dict in zip(df.keys(), df.values()):
        # The template index holds integer exit codes; incoming keys may be strings.
        exit_code = int(exit_code)
        for site, count in site_dict.items():

            chosen_site = None
            site_present_in_training_data = site in sparse_df.keys()
            if not site_present_in_training_data:

                # First fallback: drop the last name token (e.g. '_Disk')
                # and see whether the shortened name is a known column.
                site = modified_site_name(site)
                cond = site in sparse_df.keys()
                if cond:
                    chosen_site = site

                print "Detected a site %s which was not present in the training dataset" % site
                print "We would use a proxy site for this based on whether it is T1, T2 or T3"
                # Tier digit of the (possibly shortened) name, e.g. 'T2' -> '2'.
                # NOTE(review): this raises IndexError for names like 'NA' or
                # an empty string -- confirm inputs are always 'T<d>_...'.
                tier = site.split("_")[0][1]
                if chosen_site is None:
                    # Second fallback: pick a random known site of the same
                    # tier.  NOTE(review): random.randint(0, n-1) raises
                    # ValueError when the tier list is empty -- verify every
                    # tier is represented in the training columns.
                    if tier == '1':
                        chosen_num = random.randint(0, n1-1)
                        chosen_site = tier1_sites[chosen_num]
                    elif tier == '2':
                        chosen_num = random.randint(0, n2-1)
                        chosen_site = tier2_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    elif tier == '3':
                        chosen_num = random.randint(0, n3-1)
                        chosen_site = tier3_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    elif tier == '0':
                        chosen_num = random.randint(0, n0-1)
                        chosen_site = tier0_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    else:
                        # Unrecognised tier: drop this site's count entirely.
                        continue
            if chosen_site is None:
                # Site was already a known column: use it directly.
                chosen_site = site

            # NaN in either the incoming count or the template cell zeroes
            # the cell; otherwise the raw count OVERWRITES (does not
            # accumulate into) the cell.  NOTE(review): if two input sites
            # map to the same proxy column, the later one wins -- confirm.
            if np.isnan(count) or np.isnan(sparse_df.loc[exit_code, chosen_site]):
                sparse_df.loc[exit_code, chosen_site] = 0
            else:
                sparse_df.loc[exit_code, chosen_site] = count

    return sparse_df
| 83 | + |
| 84 | + |
def list_of_sites(x):
    """
    Return the site-name key collections for each site dict in *x*,
    or the placeholder list ['NA'] when *x* is empty.
    """
    per_error_sites = [site_dict.keys() for site_dict in x]
    return per_error_sites if per_error_sites else ['NA']
| 87 | + |
def build_table_flatten(x):
    """
    Flatten a table column-by-column into one flat list of cell values.

    :param x: anything indexable by its own iteration keys (e.g. a
        pandas.DataFrame, iterated column by column)
    :returns: every cell value, columns concatenated in iteration order
    :rtype: list
    """
    return [cell for column_name in x for cell in x[column_name]]
| 97 | + |
| 98 | + |
def pred(errors):
    """
    Evaluate the trained model on a batch of workflow error reports and
    return one predicted action name per workflow.

    :param list errors: list of dicts, each mapping workflow name ->
        {exit code: {site name: count}} (the shape produced by
        WorkflowInfo.get_errors -- confirm against caller)
    :returns: predicted action names, or ['TBD'] when the model artefacts
        are not available locally
    :rtype: list
    """
    # Needs all of these files to be local
    for filename in ['sparse_table.csv', 'actionfile.txt', 'my_model.h5']:
        if not os.path.exists(filename):
            return ['TBD']

    # Flatten the batch into (workflow, error_dict) rows.
    df = pd.DataFrame(columns=('workflow', 'errors'))
    base_data = []
    for i in errors:
        it = i.items()
        base_data.extend(it)

    for i, dat in enumerate(base_data):
        workflow, error_dict = dat[0], dat[1]

        # 'NotReported' is a pseudo exit code; remap it to -1 so every key
        # can later be cast with int().
        if 'NotReported' in error_dict:
            error_dict[-1] = error_dict.pop('NotReported')

        df.loc[i] = [workflow, error_dict]

    # Training-time feature layout: exit-code rows x site columns, zeroed.
    template_table = pd.read_csv("sparse_table.csv").set_index("Unnamed: 0")
    template_table[:] = 0
    df['errors_sites_exit_codes'] = df['errors'].apply(lambda x: x.keys() if x else ['0'])

    df['errors_sites_dict'] = df['errors'].apply(lambda x: x.values() if x else [{'NA': 0}])

    df['errors_sites_list'] = df['errors_sites_dict'].apply(list_of_sites)

    list2d = df['errors_sites_exit_codes'].tolist()

    # NOTE(review): sites_exit_codes and site_names are computed but never
    # used below -- presumably leftovers from the training notebook.
    sites_exit_codes = sorted(set(list(itertools.chain.from_iterable(list2d))), key=int)
    sites_exit_codes = [str(x) for x in sites_exit_codes]

    list2d_step1 = df['errors_sites_list'].tolist()
    list2d_step2 = list(itertools.chain.from_iterable(list2d_step1))
    site_names = sorted(set(list(itertools.chain.from_iterable(list2d_step2))))
    site_names = [str(x) for x in site_names]

    # One sparse count table per workflow, flattened into a feature vector.
    df['table_sites'] = df['errors'].apply(
        lambda x: build_table(x, template_table))
    df['table_sites_flatten'] = df['table_sites'].apply(build_table_flatten)
    x_dataframe = df.loc[:, "table_sites_flatten"]
    x_matrix = x_dataframe.values
    feature_size = len(x_matrix[0])
    res = []
    # clip_length equals feature_size, so this clip is a no-op safety guard.
    clip_length = len(x_matrix[0])
    for i in x_matrix:
        i_clipped = i[:clip_length]
        res.extend(i_clipped)

    # Rows containing any NaN are dropped before prediction.
    res = np.asarray(res).reshape(-1, feature_size)
    mask = ~np.any(pd.isnull(res), axis=1)
    res = res[mask]

    model = K.models.load_model('my_model.h5')
    # Round the (presumably one-hot-ish) class scores to 0/1 before argmax.
    predicted_actions_encoded = model.predict(np.array(np.asfarray(res)))
    predicted_actions_encoded = np.round(predicted_actions_encoded)

    # actionfile.txt: one '<action> <code>' pair per line (tab-delimited
    # read, then space-split) -> {code: action name}.
    action_code_dictionary = {}
    a = np.genfromtxt("actionfile.txt", delimiter='\t', dtype=str)
    b = list(i.split(' ') for i in a)
    for i in b:

        action_code_dictionary[int(i[1])] = i[0]

    predicted_actions = []
    for i in predicted_actions_encoded:
        pos = np.argmax(i)

        # NOTE(review): the fallback appends the int -1 into a list of
        # strings -- confirm downstream consumers tolerate the mixed type.
        if pos in action_code_dictionary:
            predicted_actions.append(action_code_dictionary[pos])
        else:
            predicted_actions.append(-1)

    # Free the TensorFlow graph so repeated calls do not leak memory.
    K.backend.clear_session()

    return predicted_actions
7 | 176 |
|
8 | 177 |
|
def predict(wf_obj):
    """
    Takes the errors for a workflow and makes an action prediction
    :param workflowwebtool.workflowinfo.WorkflowInfo wf_obj:
        The WorkflowInfo object that we want to perform a prediction on
    :returns: Prediction results to be passed back to a browser
    :rtype: dict
    """

    # pred() works on a batch of workflows and returns one action per
    # workflow, so wrap this single workflow in a list and take the first.
    workflow_errors = wf_obj.get_errors(True)
    predicted = pred([workflow_errors])
    return {'Action': predicted[0]}
0 commit comments