Skip to content

Commit 30b482f

Browse files
author
Daniel Abercrombie
committed
Merge branch 'master' into quick-fix
Conflicts: workflowwebtools/__init__.py workflowwebtools/predict/evaluate.py
2 parents d0a81f9 + 95c4141 commit 30b482f

10 files changed

Lines changed: 284 additions & 54 deletions

File tree

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@
2222
'cherrypy<18.0.0',
2323
'mako',
2424
'numpy>=1.6.1',
25-
'scipy>=0.19.1',
25+
'scipy==1.1.0',
2626
'sklearn',
2727
'passlib>=1.6',
2828
'bcrypt',
2929
'pyOpenSSL',
3030
'pyyaml',
3131
'validators',
3232
'tabulate',
33-
'pymongo<3.5.0'
33+
'pymongo<3.5.0',
34+
'cx_Oracle',
35+
'pandas',
36+
'keras',
37+
'tensorflow'
3438
]
3539
)

workflowwebtools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
:author: Daniel Abercrombie <dabercro@mit.edu>
55
"""
66

7-
__version__ = '0.7.3'
7+
__version__ = '0.9.2'
88

99
__all__ = []

workflowwebtools/classifyerrors.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
import re
1010

11+
from collections import defaultdict
12+
1113
from .procedures import PROCEDURES
12-
from .globalerrors import check_session
1314

14-
def classifyerror(errorcode, workflow, session=None):
15+
def classifyerror(errorcode, workflow):
1516
"""
1617
Return the most relevant characteristics of an error code for this session.
1718
This will include things like:
@@ -24,8 +25,7 @@ def classifyerror(errorcode, workflow, session=None):
2425
More error types should be added to this function as needed
2526
2627
:param int errorcode: The error code that we want to classify
27-
:param str workflow: the workflow that we want to get the errors from
28-
:param cherrypy.Session session: Is the user's cherrypy session
28+
:param workflowinfo.WorkflowInfo workflow: the workflow that we want to get the errors from
2929
:returns: A tuple of strings describing the key characteristics of the errorcode.
3030
These strings are good for printing directly in web browsers.
3131
The first string is the types of errors reported with this error code.
@@ -36,7 +36,7 @@ def classifyerror(errorcode, workflow, session=None):
3636

3737
procedure = PROCEDURES.get(errorcode, {})
3838

39-
logs = check_session(session).get_workflow(workflow).get_explanation(str(errorcode))
39+
logs = workflow.get_explanation(str(errorcode))
4040

4141
error_re = re.compile(r'[\w\s]+ \(Exit code: (\d+)\)')
4242
error_types = {}
@@ -79,26 +79,29 @@ def classifyerror(errorcode, workflow, session=None):
7979
additional_actions_string.replace(' |br| |br| ', '<br>'))
8080

8181

82-
def get_max_errorcode(workflow):
    """
    Get the errorcode with the most errors for a session

    :param workflowinfo.WorkflowInfo workflow: the workflow that we want to get the errors from
    :returns: The error code that appears most often for this workflow
    :rtype: int
    """

    totals = defaultdict(int)

    # Tally every reported error count, keyed by the integer error code.
    # 'NotReported' entries are folded into the sentinel code -1.
    for counts_by_code in workflow.get_errors(True).values():
        for code_str, site_counts in counts_by_code.items():
            code = -1 if code_str == 'NotReported' else int(code_str)
            totals[code] += sum(site_counts.values())

    # Pick the code with the highest tally; 0 when nothing beats the start value
    best_code = 0
    best_count = 0
    for code, count in totals.items():
        if count > best_count:
            best_code = code
            best_count = count

    return best_code

workflowwebtools/default/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ cluster:
8181
cache_refresh:
8282
errors: 345600
8383
workspace: '.'
84+
refresh_period: 15

workflowwebtools/errorutils.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,33 @@
1414

1515
import validators
1616
import cherrypy
17+
import cx_Oracle
1718

1819
from cmstoolbox import sitereadiness
1920
from cmstoolbox.webtools import get_json
2021

2122
from . import workflowinfo
2223
from . import serverconfig
2324

25+
def errors_from_list(workflows):
    """
    :param list workflows: A list of workflows that are in assistance-manual
    :returns: The errors for the workflows
    :rtype: dict
    """
    output = {}

    for name in workflows:
        # Expand each workflow to the full set of workflows sharing its prep ID
        # and merge the errors of every one of them into a single dictionary
        prep_id = workflowinfo.WorkflowInfo(name).get_prep_id()
        for wkf in set(workflowinfo.PrepIDInfo(prep_id).get_workflows()):
            output.update(
                workflowinfo.WorkflowInfo(wkf).get_errors(get_unreported=True)
            )

    return output
42+
43+
2444
def open_location(data_location):
2545
"""
2646
This function assumes that the contents of the location is in JSON format.
@@ -30,8 +50,19 @@ def open_location(data_location):
3050
:returns: information in the JSON file
3151
:rtype: dict
3252
"""
53+
config_dict = serverconfig.config_dict()
54+
55+
if 'oracle' in config_dict:
56+
oracle_db_conn = cx_Oracle.connect(*config_dict['oracle']) # pylint:disable=c-extension-no-member
57+
oracle_cursor = oracle_db_conn.cursor()
58+
oracle_cursor.execute(
59+
"SELECT NAME FROM CMS_UNIFIED_ADMIN.workflow WHERE lower(STATUS) LIKE '%manual%'")
60+
wkfs = [row for row, in oracle_cursor]
61+
oracle_db_conn.close()
62+
cherrypy.log('Number of workflows from database: %i' % len(wkfs))
63+
return errors_from_list(wkfs)
64+
3365
raw = None
34-
indict = {}
3566

3667
if os.path.isfile(data_location):
3768
with open(data_location, 'r') as input_file:
@@ -41,7 +72,7 @@ def open_location(data_location):
4172
components = urlparse.urlparse(data_location)
4273

4374
# Anything we need for the Shibboleth cookie could be in the config file
44-
cookie_stuff = serverconfig.config_dict()['data']
75+
cookie_stuff = config_dict['data']
4576

4677
raw = get_json(components.netloc, components.path,
4778
use_https=True,
@@ -56,16 +87,10 @@ def open_location(data_location):
5687
if not (keys and isinstance(raw[keys[0]], list)):
5788
return raw
5889

59-
for workflow, statuses in raw.iteritems():
60-
if True in ['manual' in status for status in statuses]:
61-
base = workflowinfo.WorkflowInfo(workflow)
62-
prep_id = base.get_prep_id()
63-
for wkf in set(workflowinfo.PrepIDInfo(prep_id).get_workflows()):
64-
indict.update(
65-
workflowinfo.WorkflowInfo(wkf).get_errors(get_unreported=True)
66-
)
67-
68-
return indict
90+
return errors_from_list([
91+
workflow for workflow, statuses in raw.iteritems()
92+
if True in ['manual' in status for status in statuses]
93+
])
6994

7095

7196
def get_list_info(status_list):

workflowwebtools/globalerrors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,8 @@ def check_session(session, can_refresh=False):
357357
GLOBAL_LOCK.release()
358358

359359
# If session ErrorInfo is old, set up another connection
360-
if can_refresh and theinfo.timestamp < time.time() - 60*30:
360+
if can_refresh and theinfo.timestamp < time.time() - \
361+
60*serverconfig.config_dict()['refresh_period']:
361362
theinfo.teardown()
362363
theinfo.setup()
363364

Lines changed: 175 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,189 @@
1+
# pylint: disable=missing-docstring, too-complex, invalid-name, too-many-branches, too-many-locals
2+
13
"""
24
A module that evaluates a model and returns the prediction
35
"""
46

5-
#from cmstoolbox import sitereadiness
7+
import os
8+
import random
9+
import itertools
10+
11+
import numpy as np
12+
import pandas as pd
13+
import keras as K
14+
15+
16+
def modified_site_name(site):
    """
    Drop the last underscore-separated token of a site name,
    e.g. 'T1_US_FNAL_Disk' -> 'T1_US_FNAL'.
    """
    leading_tokens = site.split('_')[:-1]
    # Rebuild with trailing separators, then strip them all off the end,
    # which also collapses any empty trailing tokens
    rebuilt = ''.join(token + '_' for token in leading_tokens)
    return rebuilt.rstrip('_')
23+
24+
def build_table(df, template_table):
    """
    Fill a copy of *template_table* with the error counts in *df*.

    :param dict df: maps exit codes (as strings) to ``{site name: count}``
        dictionaries for one workflow
    :param pandas.DataFrame template_table: zeroed table whose columns are the
        site names known from training and whose index is exit codes
    :returns: a filled copy of the template; sites unknown to the training data
        are mapped to a proxy site of the same tier
    :rtype: pandas.DataFrame
    """
    sparse_df = template_table.copy()

    # Bucket the training-time site names by their tier digit
    # (second character of names like 'T2_CH_CERN'); 'NA' is skipped
    tier0_sites, tier1_sites, tier2_sites, tier3_sites = [], [], [], []
    for i in sparse_df.keys():
        if i != 'NA':
            if i[1] == '0':
                tier0_sites.append(i)
            elif i[1] == '1':
                tier1_sites.append(i)
            elif i[1] == '2':
                tier2_sites.append(i)
            elif i[1] == '3':
                tier3_sites.append(i)

    n0, n1, n2, n3 = len(tier0_sites), len(tier1_sites), len(tier2_sites), len(tier3_sites)
    for exit_code, site_dict in zip(df.keys(), df.values()):
        exit_code = int(exit_code)
        for site, count in site_dict.items():

            chosen_site = None
            site_present_in_training_data = site in sparse_df.keys()
            if not site_present_in_training_data:

                # First try the site name with its last token stripped
                # (e.g. '..._Disk' suffixes) before falling back to a proxy
                site = modified_site_name(site)
                cond = site in sparse_df.keys()
                if cond:
                    chosen_site = site

                print "Detected a site %s which was not present in the training dataset" % site
                print "We would use a proxy site for this based on whether it is T1, T2 or T3"
                tier = site.split("_")[0][1]
                if chosen_site is None:
                    # Pick a random known site of the same tier as a proxy;
                    # NOTE(review): unknown tiers are silently dropped (continue)
                    if tier == '1':
                        chosen_num = random.randint(0, n1-1)
                        chosen_site = tier1_sites[chosen_num]
                    elif tier == '2':
                        chosen_num = random.randint(0, n2-1)
                        chosen_site = tier2_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    elif tier == '3':
                        chosen_num = random.randint(0, n3-1)
                        chosen_site = tier3_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    elif tier == '0':
                        chosen_num = random.randint(0, n0-1)
                        chosen_site = tier0_sites[chosen_num]
                        print 'The chosen site is ', chosen_site
                    else:
                        continue
            if chosen_site is None:
                chosen_site = site

            # Guard against NaN counts or NaN cells; otherwise record the count
            if np.isnan(count) or np.isnan(sparse_df.loc[exit_code, chosen_site]):
                sparse_df.loc[exit_code, chosen_site] = 0
            else:
                sparse_df.loc[exit_code, chosen_site] = count

    return sparse_df
83+
84+
85+
def list_of_sites(x):
    """Return the site-name keys of each entry, or ['NA'] when *x* is empty."""
    per_entry_keys = [entry.keys() for entry in x]
    return per_entry_keys if per_entry_keys else ['NA']
87+
88+
def build_table_flatten(x):
    """Flatten the table column by column into one flat list of cell values."""
    return [cell for column in x for cell in x[column]]
97+
98+
99+
def pred(errors):
    """
    Run the trained Keras model over the error tables of a set of workflows.

    :param list errors: a list of dicts, each mapping workflow name to its
        ``{exit code: {site: count}}`` error dictionary
    :returns: one predicted action string per workflow row, or ``['TBD']``
        when the model files are not available locally
    :rtype: list
    """
    # Needs all of these files to be local
    for filename in ['sparse_table.csv', 'actionfile.txt', 'my_model.h5']:
        if not os.path.exists(filename):
            return ['TBD']

    # Flatten the input into (workflow, error_dict) rows of a DataFrame
    df = pd.DataFrame(columns=('workflow', 'errors'))
    base_data = []
    for i in errors:
        it = i.items()
        base_data.extend(it)

    for i, dat in enumerate(base_data):
        workflow, error_dict = dat[0], dat[1]

        # 'NotReported' is encoded as exit code -1, matching the training data
        if 'NotReported' in error_dict:
            error_dict[-1] = error_dict.pop('NotReported')

        df.loc[i] = [workflow, error_dict]

    # Zeroed template with the training-time (exit code x site) layout
    template_table = pd.read_csv("sparse_table.csv").set_index("Unnamed: 0")
    template_table[:] = 0
    df['errors_sites_exit_codes'] = df['errors'].apply(lambda x: x.keys() if x else ['0'])

    df['errors_sites_dict'] = df['errors'].apply(lambda x: x.values() if x else [{'NA': 0}])

    df['errors_sites_list'] = df['errors_sites_dict'].apply(list_of_sites)

    list2d = df['errors_sites_exit_codes'].tolist()

    # NOTE(review): sites_exit_codes and site_names are computed but never
    # used below -- presumably left over from the training notebook
    sites_exit_codes = sorted(set(list(itertools.chain.from_iterable(list2d))), key=int)
    sites_exit_codes = [str(x) for x in sites_exit_codes]

    list2d_step1 = df['errors_sites_list'].tolist()
    list2d_step2 = list(itertools.chain.from_iterable(list2d_step1))
    site_names = sorted(set(list(itertools.chain.from_iterable(list2d_step2))))
    site_names = [str(x) for x in site_names]

    # Build one filled table per workflow and flatten it into a feature vector
    df['table_sites'] = df['errors'].apply(
        lambda x: build_table(x, template_table))
    df['table_sites_flatten'] = df['table_sites'].apply(build_table_flatten)
    x_dataframe = df.loc[:, "table_sites_flatten"]
    x_matrix = x_dataframe.values
    feature_size = len(x_matrix[0])
    res = []
    # clip_length equals feature_size here, so clipping is effectively a no-op
    clip_length = len(x_matrix[0])
    for i in x_matrix:
        i_clipped = i[:clip_length]
        res.extend(i_clipped)

    # Reshape to (n_workflows, feature_size) and drop rows containing NaN
    res = np.asarray(res).reshape(-1, feature_size)
    mask = ~np.any(pd.isnull(res), axis=1)
    res = res[mask]

    # One-hot-style prediction: round, then take the argmax per row
    model = K.models.load_model('my_model.h5')
    predicted_actions_encoded = model.predict(np.array(np.asfarray(res)))
    predicted_actions_encoded = np.round(predicted_actions_encoded)

    # Map encoded positions back to action names via the tab/space-delimited
    # actionfile.txt ("<action> <code>" pairs)
    action_code_dictionary = {}
    a = np.genfromtxt("actionfile.txt", delimiter='\t', dtype=str)
    b = list(i.split(' ') for i in a)
    for i in b:

        action_code_dictionary[int(i[1])] = i[0]

    predicted_actions = []
    for i in predicted_actions_encoded:
        pos = np.argmax(i)

        # -1 marks predictions outside the known action codes
        if pos in action_code_dictionary:
            predicted_actions.append(action_code_dictionary[pos])
        else:
            predicted_actions.append(-1)

    # Free the TF/Keras session so repeated calls do not leak graph state
    K.backend.clear_session()

    return predicted_actions
7176

8177

9-
def predict(wf_obj):
    """
    Takes the errors for a workflow and makes an action prediction

    :param workflowwebtool.workflowinfo.WorkflowInfo wf_obj:
        The WorkflowInfo object that we want to perform a prediction on
    :returns: Prediction results to be passed back to a browser
    :rtype: dict
    """

    workflow_errors = wf_obj.get_errors(True)
    predicted_action = pred([workflow_errors])[0]

    return {'Action': predicted_action}

0 commit comments

Comments
 (0)