Skip to content

Commit 93abaea

Browse files
authored
Merge pull request #244 from igerber/plotly-kwargs
Thread styling kwargs through plotly visualization backends
2 parents eacd38c + e5ee9ee commit 93abaea

6 files changed

Lines changed: 293 additions & 16 deletions

File tree

TODO.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ Deferred items from PR reviews that were not addressed before merge.
7171

7272
| Issue | Location | PR | Priority |
7373
|-------|----------|----|----------|
74-
| Plotly renderers silently ignore styling kwargs (marker, markersize, linewidth, capsize, ci_linewidth) that the matplotlib backend honors; thread them through or reject when `backend="plotly"` | `visualization/_event_study.py`, `_diagnostic.py`, `_power.py` | #222 | Medium |
7574
| R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
7675
| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
7776
| ~376 `duplicate object description` Sphinx warnings — restructure `docs/api/*.rst` to avoid duplicate `:members:` + `autosummary` | `docs/api/*.rst` || Low |

diff_diff/visualization/_common.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,42 @@ def _color_to_rgba(color, alpha=1.0):
283283
)
284284

285285

286+
# Matplotlib marker code -> plotly symbol name mapping
287+
_MPL_TO_PLOTLY_SYMBOL = {
288+
"o": "circle",
289+
"s": "square",
290+
"D": "diamond",
291+
"d": "diamond",
292+
"^": "triangle-up",
293+
"v": "triangle-down",
294+
"<": "triangle-left",
295+
">": "triangle-right",
296+
"p": "pentagon",
297+
"h": "hexagon",
298+
"+": "cross",
299+
"x": "x",
300+
"*": "star",
301+
".": "circle",
302+
}
303+
304+
305+
def _mpl_marker_to_plotly_symbol(marker):
306+
"""Convert a matplotlib marker code to a plotly symbol name.
307+
308+
Parameters
309+
----------
310+
marker : str
311+
Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``).
312+
313+
Returns
314+
-------
315+
str
316+
Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``).
317+
Returns ``"circle"`` for unrecognized markers.
318+
"""
319+
return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle")
320+
321+
286322
# Default color constants
287323
DEFAULT_BLUE = "#2563eb"
288324
DEFAULT_RED = "#dc2626"

diff_diff/visualization/_diagnostic.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def plot_sensitivity(
111111
bounds_color=bounds_color,
112112
bounds_alpha=bounds_alpha,
113113
ci_color=ci_color,
114+
ci_linewidth=ci_linewidth,
114115
breakdown_color=breakdown_color,
115116
original_color=original_color,
116117
show=show,
@@ -242,6 +243,7 @@ def _render_sensitivity_plotly(
242243
bounds_color,
243244
bounds_alpha,
244245
ci_color,
246+
ci_linewidth,
245247
breakdown_color,
246248
original_color,
247249
show,
@@ -291,7 +293,7 @@ def _render_sensitivity_plotly(
291293
x=M_list,
292294
y=list(ci_arr[:, 0]),
293295
mode="lines",
294-
line=dict(color=ci_color, width=1.5),
296+
line=dict(color=ci_color, width=ci_linewidth),
295297
name="Robust CI",
296298
)
297299
)
@@ -300,7 +302,7 @@ def _render_sensitivity_plotly(
300302
x=M_list,
301303
y=list(ci_arr[:, 1]),
302304
mode="lines",
303-
line=dict(color=ci_color, width=1.5),
305+
line=dict(color=ci_color, width=ci_linewidth),
304306
showlegend=False,
305307
)
306308
)
@@ -449,6 +451,8 @@ def plot_bacon(
449451
xlabel=xlabel,
450452
ylabel=ylabel,
451453
colors=colors,
454+
marker=marker,
455+
markersize=markersize,
452456
alpha=alpha,
453457
show_weighted_avg=show_weighted_avg,
454458
show_twfe_line=show_twfe_line,
@@ -699,13 +703,19 @@ def _render_bacon_plotly(
699703
xlabel,
700704
ylabel,
701705
colors,
706+
marker,
707+
markersize,
702708
alpha,
703709
show_weighted_avg,
704710
show_twfe_line,
705711
show,
706712
):
707713
"""Render Bacon decomposition plot with plotly."""
708-
from diff_diff.visualization._common import _plotly_default_layout, _require_plotly
714+
from diff_diff.visualization._common import (
715+
_mpl_marker_to_plotly_symbol,
716+
_plotly_default_layout,
717+
_require_plotly,
718+
)
709719

710720
go = _require_plotly()
711721

@@ -727,6 +737,10 @@ def _render_bacon_plotly(
727737
"later_vs_earlier": "Later vs Earlier (forbidden)",
728738
}
729739

740+
# Convert matplotlib scatter area (points^2) to plotly diameter (px)
741+
plotly_size = max(1, int(round(markersize**0.5)))
742+
symbol = _mpl_marker_to_plotly_symbol(marker)
743+
730744
for ctype, points in by_type.items():
731745
if not points:
732746
continue
@@ -737,7 +751,12 @@ def _render_bacon_plotly(
737751
x=estimates,
738752
y=weights,
739753
mode="markers",
740-
marker=dict(color=colors[ctype], size=10, opacity=alpha),
754+
marker=dict(
755+
color=colors[ctype],
756+
size=plotly_size,
757+
symbol=symbol,
758+
opacity=alpha,
759+
),
741760
name=labels[ctype],
742761
)
743762
)

diff_diff/visualization/_event_study.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def plot_event_study(
272272
xlabel=xlabel,
273273
ylabel=ylabel,
274274
color=color,
275+
marker=marker,
276+
markersize=markersize,
275277
shade_pre=shade_pre,
276278
shade_color=shade_color,
277279
show_zero_line=show_zero_line,
@@ -422,6 +424,8 @@ def _render_event_study_plotly(
422424
xlabel,
423425
ylabel,
424426
color,
427+
marker,
428+
markersize,
425429
shade_pre,
426430
shade_color,
427431
show_zero_line,
@@ -431,6 +435,7 @@ def _render_event_study_plotly(
431435
"""Render event study plot with plotly."""
432436
from diff_diff.visualization._common import (
433437
_color_to_rgba,
438+
_mpl_marker_to_plotly_symbol,
434439
_plotly_default_layout,
435440
_require_plotly,
436441
)
@@ -504,13 +509,15 @@ def _render_event_study_plotly(
504509

505510
hover_tpl = "Period: %{customdata}<br>Effect: %{y:.4f}<extra></extra>"
506511

512+
symbol = _mpl_marker_to_plotly_symbol(marker)
513+
507514
if non_ref_x:
508515
fig.add_trace(
509516
go.Scatter(
510517
x=non_ref_x,
511518
y=non_ref_e,
512519
mode="markers",
513-
marker=dict(color=color, size=10),
520+
marker=dict(color=color, size=markersize, symbol=symbol),
514521
name="Effect",
515522
customdata=non_ref_labels,
516523
hovertemplate=hover_tpl,
@@ -525,7 +532,8 @@ def _render_event_study_plotly(
525532
mode="markers",
526533
marker=dict(
527534
color="white",
528-
size=10,
535+
size=markersize,
536+
symbol=symbol,
529537
line=dict(color=color, width=2),
530538
),
531539
name="Reference",
@@ -842,6 +850,8 @@ def plot_honest_event_study(
842850
ylabel=ylabel,
843851
original_color=original_color,
844852
honest_color=honest_color,
853+
marker=marker,
854+
markersize=markersize,
845855
show=show,
846856
)
847857

@@ -987,11 +997,14 @@ def _render_honest_event_study_plotly(
987997
ylabel,
988998
original_color,
989999
honest_color,
1000+
marker,
1001+
markersize,
9901002
show,
9911003
):
9921004
"""Render honest event study plot with plotly."""
9931005
from diff_diff.visualization._common import (
9941006
_color_to_rgba,
1007+
_mpl_marker_to_plotly_symbol,
9951008
_plotly_default_layout,
9961009
_require_plotly,
9971010
)
@@ -1036,13 +1049,15 @@ def _render_honest_event_study_plotly(
10361049
ref_p = [p for p, r in zip(periods, is_ref) if r]
10371050
ref_e = [e for e, r in zip(effects, is_ref) if r]
10381051

1052+
symbol = _mpl_marker_to_plotly_symbol(marker)
1053+
10391054
if non_ref_p:
10401055
fig.add_trace(
10411056
go.Scatter(
10421057
x=non_ref_p,
10431058
y=non_ref_e,
10441059
mode="markers",
1045-
marker=dict(color=honest_color, size=10),
1060+
marker=dict(color=honest_color, size=markersize, symbol=symbol),
10461061
name="Effect",
10471062
)
10481063
)
@@ -1053,7 +1068,12 @@ def _render_honest_event_study_plotly(
10531068
x=ref_p,
10541069
y=ref_e,
10551070
mode="markers",
1056-
marker=dict(color="white", size=10, line=dict(color=honest_color, width=2)),
1071+
marker=dict(
1072+
color="white",
1073+
size=markersize,
1074+
symbol=symbol,
1075+
line=dict(color=honest_color, width=2),
1076+
),
10571077
name="Reference",
10581078
)
10591079
)

diff_diff/visualization/_power.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,10 @@ def plot_power_curve(
165165
color=color,
166166
mde_color=mde_color,
167167
target_color=target_color,
168+
linewidth=linewidth,
168169
show_mde_line=show_mde_line,
169170
show_target_line=show_target_line,
171+
show_grid=show_grid,
170172
show=show,
171173
)
172174

@@ -291,8 +293,10 @@ def _render_power_curve_plotly(
291293
color,
292294
mde_color,
293295
target_color,
296+
linewidth,
294297
show_mde_line,
295298
show_target_line,
299+
show_grid,
296300
show,
297301
):
298302
"""Render power curve with plotly."""
@@ -307,7 +311,7 @@ def _render_power_curve_plotly(
307311
x=effect_sizes,
308312
y=powers,
309313
mode="lines",
310-
line=dict(color=color, width=2),
314+
line=dict(color=color, width=linewidth),
311315
name="Power",
312316
)
313317
)
@@ -331,7 +335,8 @@ def _render_power_curve_plotly(
331335
)
332336

333337
_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
334-
fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
338+
fig.update_xaxes(showgrid=show_grid)
339+
fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)
335340

336341
if show:
337342
fig.show()
@@ -482,8 +487,10 @@ def plot_pretrends_power(
482487
color=color,
483488
mdv_color=mdv_color,
484489
target_color=target_color,
490+
linewidth=linewidth,
485491
show_mdv_line=show_mdv_line,
486492
show_target_line=show_target_line,
493+
show_grid=show_grid,
487494
show=show,
488495
)
489496

@@ -602,8 +609,10 @@ def _render_pretrends_power_plotly(
602609
color,
603610
mdv_color,
604611
target_color,
612+
linewidth,
605613
show_mdv_line,
606614
show_target_line,
615+
show_grid,
607616
show,
608617
):
609618
"""Render pre-trends power curve with plotly."""
@@ -619,7 +628,7 @@ def _render_pretrends_power_plotly(
619628
x=M_values,
620629
y=powers,
621630
mode="lines",
622-
line=dict(color=color, width=2),
631+
line=dict(color=color, width=linewidth),
623632
name="Power",
624633
)
625634
)
@@ -643,7 +652,8 @@ def _render_pretrends_power_plotly(
643652
)
644653

645654
_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
646-
fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
655+
fig.update_xaxes(showgrid=show_grid)
656+
fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)
647657

648658
if show:
649659
fig.show()

0 commit comments

Comments
 (0)