Skip to content

Commit 165b7c0

Browse files
committed
Build out tests
1 parent d413536 commit 165b7c0

4 files changed

Lines changed: 331 additions & 23 deletions

File tree

src/scmdata/plotting.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def plumeplot( # pragma: no cover
158158
159159
quantile_over : str, tuple[str]
160160
Columns of ``self.meta`` over which the quantiles should be calculated.
161-
Only used if ``pre_calculated`` is ``True``.
161+
Only used if ``pre_calculated`` is ``False``.
162162
163163
Returns
164164
-------
@@ -236,15 +236,24 @@ def plumeplot( # pragma: no cover
236236
style_value = hsdf.get_unique_meta(style_var, no_duplicates=True)
237237

238238
xaxis = hsdf.timeseries(time_axis=time_axis).columns.tolist()
239-
if hue_value in _palette:
239+
if palette is not None:
240+
try:
241+
pkwargs["color"] = _palette[hue_value]
242+
except KeyError as exc:
243+
error_msg = "{} not in palette: {}".format(
244+
hue_value, palette
245+
)
246+
raise KeyError(error_msg) from exc
247+
248+
elif hue_value in _palette:
240249
pkwargs["color"] = _palette[hue_value]
241250

242251
if len(q) == 2:
243252
label = "{:.0f}th - {:.0f}th".format(q[0] * 100, q[1] * 100)
244253
p = ax.fill_between(
245254
xaxis,
246-
hsdf.filter(quantile=q[0]).values.squeeze(),
247-
hsdf.filter(quantile=q[1]).values.squeeze(),
255+
_get_1d_or_raise(hsdf.filter(quantile=q[0]), hue_var, style_var),
256+
_get_1d_or_raise(hsdf.filter(quantile=q[1]), hue_var, style_var),
248257
label=label,
249258
**pkwargs
250259
)
@@ -255,8 +264,14 @@ def plumeplot( # pragma: no cover
255264
elif len(q) == 1:
256265
_plotted_lines = True
257266

258-
if style_value in _dashes:
259-
pkwargs["linestyle"] = _dashes[style_value]
267+
if dashes is not None:
268+
try:
269+
pkwargs["linestyle"] = _dashes[style_value]
270+
except KeyError as exc:
271+
error_msg = "{} not in dashes: {}".format(
272+
style_value, dashes
273+
)
274+
raise KeyError(error_msg) from exc
260275
else:
261276
_dashes[style_value] = next(linestyle_cycler)
262277
pkwargs["linestyle"] = _dashes[style_value]
@@ -268,7 +283,7 @@ def plumeplot( # pragma: no cover
268283

269284
p = ax.plot(
270285
xaxis,
271-
hsdf.filter(quantile=q[0]).values.squeeze(),
286+
_get_1d_or_raise(hsdf.filter(quantile=q[0]), hue_var, style_var),
272287
label=label,
273288
linewidth=linewidth,
274289
**pkwargs
@@ -288,7 +303,7 @@ def plumeplot( # pragma: no cover
288303

289304
# Fake the line handles for the legend
290305
hue_val_lines = [
291-
mlines.Line2D([0], [0], **{"color": _palette[hue_value]}, label=hue_value)
306+
mlines.Line2D([0], [0], color=_palette[hue_value], label=hue_value)
292307
for hue_value in self.get_unique_meta(hue_var)
293308
]
294309

@@ -304,9 +319,10 @@ def plumeplot( # pragma: no cover
304319
mlines.Line2D(
305320
[0],
306321
[0],
307-
**{"linestyle": _dashes[style_value]},
322+
linestyle=_dashes[style_value],
308323
label=style_value,
309-
color="gray"
324+
color="gray",
325+
linewidth=linewidth,
310326
)
311327
for style_value in self.get_unique_meta(style_var)
312328
]
@@ -329,6 +345,33 @@ def plumeplot( # pragma: no cover
329345
return ax, legend_items
330346

331347

348+
def _get_1d_or_raise(in_scmrun, hue_var, style_var):
349+
out_arr = in_scmrun.values.squeeze()
350+
if len(out_arr.shape) > 1:
351+
quantile = in_scmrun.get_unique_meta("quantile", True)
352+
hue_var_value = in_scmrun.get_unique_meta(hue_var, True)
353+
style_var_value = in_scmrun.get_unique_meta(style_var, True)
354+
error_msg = (
355+
"More than one timeseries for "
356+
"quantile: {}, "
357+
"{}: {}, "
358+
"{}: {}.\n"
359+
"Please process your data to create unique quantile timeseries "
360+
"before calling :meth:`plumeplot`.\n"
361+
"Found: {}".format(
362+
quantile,
363+
hue_var,
364+
hue_var_value,
365+
style_var,
366+
style_var_value,
367+
in_scmrun,
368+
)
369+
)
370+
raise ValueError(error_msg)
371+
372+
return out_arr
373+
374+
332375
def _deprecated_line_plot(self, **kwargs): # pragma: no cover
333376
"""
334377
Make a line plot via `seaborn's lineplot <https://seaborn.pydata.org/generated/seaborn.lineplot.html>`_

tests/conftest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,14 @@ def test_processing_scm_df():
329329

330330
@pytest.fixture(scope="function")
331331
def plumeplot_scmrun():
332+
n_ems = 30
332333
yield ScmRun(
333-
data=np.random.random((10, 3)).T,
334+
data=np.random.random((n_ems * 2, 3)).T,
334335
columns={
335336
"model": ["a_iam"],
336-
"climate_model": ["a_model"] * 5 + ["a_model_2"] * 5,
337-
"scenario": ["a_scenario"] * 5 + ["a_scenario_2"] * 5,
338-
"ensemble_member": list(range(5)) + list(range(5)),
337+
"climate_model": ["a_model"] * n_ems + ["a_model_2"] * n_ems,
338+
"scenario": ["a_scenario"] * n_ems + ["a_scenario_2"] * n_ems,
339+
"ensemble_member": list(range(n_ems)) + list(range(n_ems)),
339340
"region": ["World"],
340341
"variable": ["Surface Air Temperature Change"],
341342
"unit": ["K"],

0 commit comments

Comments
 (0)