@@ -14,8 +14,6 @@ def despine(ax_or_axes):
1414 sns .despine (ax = ax )
1515
1616
17-
18-
1917def clean_axis (ax ):
2018 for loc in ("top" , "right" , "left" , "bottom" ):
2119 ax .spines [loc ].set_visible (False )
@@ -30,19 +28,19 @@ def plot_study_legend(study, case_keys=None, ax=None):
3028 Make a ax with only legend
3129 """
3230 import matplotlib .pyplot as plt
33-
31+
3432 if case_keys is None :
3533 case_keys = list (study .cases .keys ())
36-
34+
3735 if ax is None :
3836 fig , ax = plt .subplots (figsize = figsize )
3937 else :
4038 fig = ax .get_figure ()
4139
4240 colors = study .get_colors ()
43-
41+
4442 for k in case_keys :
45- ax .plot ([], color = colors [k ], label = study .cases [k ][' label' ])
43+ ax .plot ([], color = colors [k ], label = study .cases [k ][" label" ])
4644 ax .legend ()
4745 clean_axis (ax )
4846 return fig
@@ -126,7 +124,6 @@ def plot_run_times(study, case_keys=None, levels_to_keep=None, figsize=None, ax=
126124 The resulting figure containing the plots
127125 """
128126 import matplotlib .pyplot as plt
129-
130127
131128 if case_keys is None :
132129 case_keys = list (study .cases .keys ())
@@ -138,7 +135,6 @@ def plot_run_times(study, case_keys=None, levels_to_keep=None, figsize=None, ax=
138135 else :
139136 fig = ax .get_figure ()
140137
141-
142138 if levels_to_keep is None :
143139 colors = study .get_colors ()
144140 labels = []
@@ -183,7 +179,6 @@ def plot_run_times(study, case_keys=None, levels_to_keep=None, figsize=None, ax=
183179 plt_fun = sns .barplot
184180 palette_keys = hues
185181
186-
187182 assert all (
188183 [key in colors for key in palette_keys ]
189184 ), f"colors must have a color for each palette key: { palette_keys } "
@@ -230,7 +225,6 @@ def plot_unit_counts(study, case_keys=None, levels_to_keep=None, colors=None, fi
230225 if case_keys is None :
231226 case_keys = list (study .cases .keys ())
232227
233-
234228 if ax is None :
235229 fig , ax = plt .subplots (figsize = figsize )
236230 else :
@@ -294,7 +288,6 @@ def plot_unit_counts(study, case_keys=None, levels_to_keep=None, colors=None, fi
294288 else :
295289 assert all ([col in colors for col in columns ]), f"colors must have a color for each column: { columns } "
296290
297-
298291 df = pd .melt (
299292 count_units .reset_index (),
300293 id_vars = levels_to_keep ,
@@ -496,12 +489,11 @@ def plot_performances_vs_snr(
496489 if levels_to_keep is not None :
497490 case_group_keys , labels = study .get_grouped_keys_mapping (levels_to_group_by = levels_to_keep )
498491 else :
499- labels = {k : study .cases [k ]['label' ] for k in case_keys }
500- case_group_keys = {k : [k ] for k in case_keys }
501-
492+ labels = {k : study .cases [k ]["label" ] for k in case_keys }
493+ case_group_keys = {k : [k ] for k in case_keys }
502494
503495 colors = study .get_colors (levels_to_group_by = levels_to_keep )
504-
496+
505497 assert all ([key in colors for key in case_keys ]), f"colors must have a color for each case key: { case_keys } "
506498
507499 for key , key_list in case_group_keys .items ():
0 commit comments