Skip to content

Commit 95c4141

Browse files
author
Daniel Abercrombie
authored
Merge pull request #118 from dabercro/rough-ml
First Working ML
2 parents 402b45c + b61ed4f commit 95c4141

7 files changed

Lines changed: 231 additions & 29 deletions

File tree

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
'validators',
3232
'tabulate',
3333
'pymongo<3.5.0',
34-
'cx_Oracle'
34+
'cx_Oracle',
35+
'pandas',
36+
'keras',
37+
'tensorflow'
3538
]
3639
)

workflowwebtools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
:author: Daniel Abercrombie <dabercro@mit.edu>
55
"""
66

7-
__version__ = '0.8.2'
7+
__version__ = '0.9.1'
88

99
__all__ = []

workflowwebtools/classifyerrors.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
import re
1010

11+
from collections import defaultdict
12+
1113
from .procedures import PROCEDURES
12-
from .globalerrors import check_session
1314

14-
def classifyerror(errorcode, workflow, session=None):
15+
def classifyerror(errorcode, workflow):
1516
"""
1617
Return the most relevant characteristics of an error code for this session.
1718
This will include things like:
@@ -24,8 +25,7 @@ def classifyerror(errorcode, workflow, session=None):
2425
More error types should be added to this function as needed
2526
2627
:param int errorcode: The error code that we want to classify
27-
:param str workflow: the workflow that we want to get the errors from
28-
:param cherrypy.Session session: Is the user's cherrypy session
28+
:param workflowinfo.WorkflowInfo workflow: the workflow that we want to get the errors from
2929
:returns: A tuple of strings describing the key characteristics of the errorcode.
3030
These strings are good for printing directly in web browsers.
3131
The first string is the types of errors reported with this error code.
@@ -36,7 +36,7 @@ def classifyerror(errorcode, workflow, session=None):
3636

3737
procedure = PROCEDURES.get(errorcode, {})
3838

39-
logs = check_session(session).get_workflow(workflow).get_explanation(str(errorcode))
39+
logs = workflow.get_explanation(str(errorcode))
4040

4141
error_re = re.compile(r'[\w\s]+ \(Exit code: (\d+)\)')
4242
error_types = {}
@@ -79,26 +79,29 @@ def classifyerror(errorcode, workflow, session=None):
7979
additional_actions_string.replace(' |br| |br| ', '<br>'))
8080

8181

82-
def get_max_errorcode(workflow, session=None):
82+
def get_max_errorcode(workflow):
8383
"""
8484
Get the errorcode with the most errors for a session
8585
86-
:param str workflow: Is the primary name of the workflow
87-
:param cherrypy.Session session: Is the user's cherrypy session
86+
:param workflowinfo.WorkflowInfo workflow: the workflow that we want to get the errors from
8887
:returns: The error code that appears most often for this workflow
8988
:rtype: int
9089
"""
9190

91+
errors = workflow.get_errors(True)
92+
errors_summed = defaultdict(int)
9293

93-
curs, _, allerrors, _ = check_session(session).info
94-
95-
num_errors = []
96-
97-
for errorcode in allerrors:
98-
output = curs.execute("SELECT SUM(numbererrors) FROM workflows WHERE "
99-
"stepname LIKE '/{0}/%' AND errorcode={1}".\
100-
format(workflow, errorcode))
94+
for codes in errors.values():
95+
for errorcode, sites in codes.items():
96+
numcode = -1 if errorcode == 'NotReported' else int(errorcode)
97+
for num in sites.values():
98+
errors_summed[numcode] += num
10199

102-
num_errors.append(output[0])
100+
output = 0
101+
max_num = 0
102+
for code, num in errors_summed.items():
103+
if num > max_num:
104+
max_num = num
105+
output = code
103106

104-
return allerrors[num_errors.index(max(num_errors))]
107+
return output
Lines changed: 175 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,189 @@
1+
# pylint: disable=missing-docstring, too-complex, invalid-name, too-many-branches, too-many-locals
2+
13
"""
24
A module that evaluates a model and returns the prediction
35
"""
46

5-
#from cmstoolbox import sitereadiness
7+
import os
8+
import random
9+
import itertools
10+
11+
import numpy as np
12+
import pandas as pd
13+
import keras as K
14+
15+
16+
def modified_site_name(site):
17+
site_name = site.split('_')[:-1]
18+
s = ''
19+
for i in site_name:
20+
s = s+i+'_'
21+
s = s.rstrip('_')
22+
return s
23+
24+
def build_table(df, template_table):
25+
sparse_df = template_table.copy()
26+
27+
tier0_sites, tier1_sites, tier2_sites, tier3_sites = [], [], [], []
28+
for i in sparse_df.keys():
29+
if i != 'NA':
30+
if i[1] == '0':
31+
tier0_sites.append(i)
32+
elif i[1] == '1':
33+
tier1_sites.append(i)
34+
elif i[1] == '2':
35+
tier2_sites.append(i)
36+
elif i[1] == '3':
37+
tier3_sites.append(i)
38+
39+
n0, n1, n2, n3 = len(tier0_sites), len(tier1_sites), len(tier2_sites), len(tier3_sites)
40+
for exit_code, site_dict in zip(df.keys(), df.values()):
41+
exit_code = int(exit_code)
42+
for site, count in site_dict.items():
43+
44+
chosen_site = None
45+
site_present_in_training_data = site in sparse_df.keys()
46+
if not site_present_in_training_data:
47+
48+
site = modified_site_name(site)
49+
cond = site in sparse_df.keys()
50+
if cond:
51+
chosen_site = site
52+
53+
print "Detected a site %s which was not present in the training dataset" % site
54+
print "We would use a proxy site for this based on whether it is T1, T2 or T3"
55+
tier = site.split("_")[0][1]
56+
if chosen_site is None:
57+
if tier == '1':
58+
chosen_num = random.randint(0, n1-1)
59+
chosen_site = tier1_sites[chosen_num]
60+
elif tier == '2':
61+
chosen_num = random.randint(0, n2-1)
62+
chosen_site = tier2_sites[chosen_num]
63+
print 'The chosen site is ', chosen_site
64+
elif tier == '3':
65+
chosen_num = random.randint(0, n3-1)
66+
chosen_site = tier3_sites[chosen_num]
67+
print 'The chosen site is ', chosen_site
68+
elif tier == '0':
69+
chosen_num = random.randint(0, n0-1)
70+
chosen_site = tier0_sites[chosen_num]
71+
print 'The chosen site is ', chosen_site
72+
else:
73+
continue
74+
if chosen_site is None:
75+
chosen_site = site
76+
77+
if np.isnan(count) or np.isnan(sparse_df.loc[exit_code, chosen_site]):
78+
sparse_df.loc[exit_code, chosen_site] = 0
79+
else:
80+
sparse_df.loc[exit_code, chosen_site] = count
81+
82+
return sparse_df
83+
84+
85+
def list_of_sites(x):
86+
return [item.keys() for item in x] or ['NA']
87+
88+
def build_table_flatten(x):
89+
d_outer = []
90+
91+
for column in x:
92+
93+
for item in x[column]:
94+
d_outer.append(item)
95+
96+
return d_outer
97+
98+
99+
def pred(errors):
100+
# Needs all of these files to be local
101+
for filename in ['sparse_table.csv', 'actionfile.txt', 'my_model.h5']:
102+
if not os.path.exists(filename):
103+
return ['TBD']
104+
105+
df = pd.DataFrame(columns=('workflow', 'errors'))
106+
base_data = []
107+
for i in errors:
108+
it = i.items()
109+
base_data.extend(it)
110+
111+
for i, dat in enumerate(base_data):
112+
workflow, error_dict = dat[0], dat[1]
113+
114+
if 'NotReported' in error_dict:
115+
error_dict[-1] = error_dict.pop('NotReported')
116+
117+
df.loc[i] = [workflow, error_dict]
118+
119+
template_table = pd.read_csv("sparse_table.csv").set_index("Unnamed: 0")
120+
template_table[:] = 0
121+
df['errors_sites_exit_codes'] = df['errors'].apply(lambda x: x.keys() if x else ['0'])
122+
123+
df['errors_sites_dict'] = df['errors'].apply(lambda x: x.values() if x else [{'NA': 0}])
124+
125+
df['errors_sites_list'] = df['errors_sites_dict'].apply(list_of_sites)
126+
127+
list2d = df['errors_sites_exit_codes'].tolist()
128+
129+
sites_exit_codes = sorted(set(list(itertools.chain.from_iterable(list2d))), key=int)
130+
sites_exit_codes = [str(x) for x in sites_exit_codes]
131+
132+
list2d_step1 = df['errors_sites_list'].tolist()
133+
list2d_step2 = list(itertools.chain.from_iterable(list2d_step1))
134+
site_names = sorted(set(list(itertools.chain.from_iterable(list2d_step2))))
135+
site_names = [str(x) for x in site_names]
136+
137+
df['table_sites'] = df['errors'].apply(
138+
lambda x: build_table(x, template_table))
139+
df['table_sites_flatten'] = df['table_sites'].apply(build_table_flatten)
140+
x_dataframe = df.loc[:, "table_sites_flatten"]
141+
x_matrix = x_dataframe.values
142+
feature_size = len(x_matrix[0])
143+
res = []
144+
clip_length = len(x_matrix[0])
145+
for i in x_matrix:
146+
i_clipped = i[:clip_length]
147+
res.extend(i_clipped)
148+
149+
res = np.asarray(res).reshape(-1, feature_size)
150+
mask = ~np.any(pd.isnull(res), axis=1)
151+
res = res[mask]
152+
153+
model = K.models.load_model('my_model.h5')
154+
predicted_actions_encoded = model.predict(np.array(np.asfarray(res)))
155+
predicted_actions_encoded = np.round(predicted_actions_encoded)
156+
157+
action_code_dictionary = {}
158+
a = np.genfromtxt("actionfile.txt", delimiter='\t', dtype=str)
159+
b = list(i.split(' ') for i in a)
160+
for i in b:
161+
162+
action_code_dictionary[int(i[1])] = i[0]
163+
164+
predicted_actions = []
165+
for i in predicted_actions_encoded:
166+
pos = np.argmax(i)
167+
168+
if pos in action_code_dictionary:
169+
predicted_actions.append(action_code_dictionary[pos])
170+
else:
171+
predicted_actions.append(-1)
172+
173+
K.backend.clear_session()
6174

175+
return predicted_actions
7176

8177

9-
def predict(errors): # pylint: disable=unused-argument
178+
def predict(wf_obj):
10179
"""
11180
Takes the errors for a workflow and makes an action prediction
12-
:param workflowwebtool.workflowinfo.WorkflowInfo errors:
181+
:param workflowwebtool.workflowinfo.WorkflowInfo wf_obj:
13182
The WorkflowInfo object that we want to perform a prediction on
14183
:returns: Prediction results to be passed back to a browser
15184
:rtype: dict
16185
"""
17186

18-
return {'Action': 'TBD'}
187+
return {
188+
'Action': pred([wf_obj.get_errors(True)])[0]
189+
}

workflowwebtools/web/static/js/submit.js

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,18 @@ function buildSubmit (workflow, sitesToRun) {
313313
return output;
314314
}
315315

316+
317+
function showStatus(workflow) {
318+
$.ajax({
319+
url: "/getstatus",
320+
data: {"workflow": workflow},
321+
success: function (status) {
322+
document.getElementById("actionstatus").innerHTML = "Action: " + status.status;
323+
}
324+
});
325+
}
326+
327+
316328
function makeForm(workflow) {
317329

318330
$.ajax({
@@ -326,14 +338,19 @@ function makeForm(workflow) {
326338
form.action = "javascript:;";
327339
form.onsubmit = function () {
328340
submission = buildSubmit(workflow, params.sitestorun);
329-
if (confirm("Will submit " + JSON.stringify(submission)))
341+
if (confirm("Will submit " + JSON.stringify(submission))) {
330342
$.ajax({
331343
url: "/submit2",
332344
type: "POST",
333345
dataType: "json",
334346
contentType : 'application/json',
335-
data: JSON.stringify({documents: submission})
347+
data: JSON.stringify({documents: submission}),
348+
success: function () {
349+
showStatus(workflow);
350+
alert('Action Submitted');
351+
}
336352
});
353+
}
337354
};
338355

339356
addOptions(form, params);
@@ -361,4 +378,5 @@ function makeForm(workflow) {
361378
function prepareSubmit (workflow) {
362379
makeForm(workflow);
363380
setReasons();
381+
showStatus(workflow);
364382
};

workflowwebtools/web/templates/workflowtables2.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
<center id="top">
2020
<h1>${workflow}</h1>
21+
<div id="actionstatus"></div>
2122
<a href="/globalerror2">Global Errors</a> <br>
2223
</center>
2324

workflowwebtools/workflowtools.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def get_status(self, workflow):
185185
return "none"
186186
return "acted" if status else "pending"
187187

188+
@cherrypy.expose
189+
@cherrypy.tools.json_out()
190+
def getstatus(self, workflow):
191+
return {'status': self.get_status(workflow).capitalize()}
188192

189193
def get(self, workflow):
190194
self.wflock.acquire()
@@ -231,10 +235,12 @@ def getworkflows(self, prepid):
231235

232236
@cherrypy.expose
233237
@cherrypy.tools.json_in()
238+
@cherrypy.tools.json_out()
234239
def submit2(self):
235240
input_json = cherrypy.request.json
236241
manageactions.submit2(input_json['documents'])
237-
return 'Done'
242+
self.update_statuses()
243+
return {'message': 'Done'}
238244

239245

240246
@cherrypy.expose
@@ -597,8 +603,8 @@ def classifyerror(self, workflow):
597603
self.seeworkflowlock.acquire()
598604

599605
try:
600-
max_error = classifyerrors.get_max_errorcode(workflow, cherrypy.session)
601-
main_error_class = classifyerrors.classifyerror(max_error, workflow, cherrypy.session)
606+
max_error = classifyerrors.get_max_errorcode(self.get(workflow))
607+
main_error_class = classifyerrors.classifyerror(max_error, self.get(workflow))
602608

603609
output = {
604610
'maxerror': max_error,

0 commit comments

Comments
 (0)