2424import functools as ft
2525
2626import numpy as np
27+ from scipy .stats import gaussian_kde
2728
2829from mpl_toolkits .mplot3d import Axes3D
2930import matplotlib .pyplot as plt
@@ -759,8 +760,6 @@ def plotfit_steady(
759760
760761def plotparainteract (result , paranames , plotname = None , fig = None , style = "WTP" ):
761762 """Plot of parameter interaction."""
762- import pandas as pd
763-
764763 style = copy .deepcopy (plt .rcParams ) if style is None else style
765764 keep_fs = False
766765 if style == "WTP" :
@@ -775,21 +774,9 @@ def plotparainteract(result, paranames, plotname=None, fig=None, style="WTP"):
775774 # font type fix (resetted in default)
776775 plt .rcParams .update ({"pdf.fonttype" : pdf_ft , "ps.fonttype" : ps_ft })
777776 plt .rcParams .update ({"font.size" : font_size })
778- fig , ax = _get_fig_ax (fig , ax = None , figsize = (12 , 12 ))
779777 fields = [par for par in result .dtype .names if par .startswith ("par" )]
780- parameterdistribtion = result [fields ]
781- df = pd .DataFrame (
782- np .asarray (parameterdistribtion ).T .tolist (), columns = paranames
783- )
784- with warnings .catch_warnings ():
785- # We know that fig is resetted, but we need to give ax to set fig
786- warnings .simplefilter ("ignore" , UserWarning )
787- if len (paranames ) > 1 :
788- pd .plotting .scatter_matrix (
789- df , alpha = 0.2 , ax = ax , diagonal = "kde"
790- )
791- else :
792- df .plot .kde (ax = ax )
778+ para = [result [:][name ] for name in fields ]
779+ fig = _scatter_matrix (para , paranames , fig )
793780 fig .tight_layout ()
794781 fig .subplots_adjust (hspace = 0 , wspace = 0 , bottom = 0.1 )
795782 if plotname is not None :
@@ -895,3 +882,55 @@ def plotsensitivity(
895882 bbox_inches = "tight" ,
896883 )
897884 return ax
885+
886+
887+ def _scatter_matrix (data , label , fig = None ):
888+ data = np .array (data , ndmin = 2 , dtype = float )
889+ n = len (data )
890+ axes = np .empty (n ** 2 , dtype = object )
891+ for i in range (n ** 2 ):
892+ fig , axes [i ] = _get_fig_ax (fig , figsize = (8 , 8 ), sub_args = (n , n , i + 1 ))
893+ axes = axes .reshape (n , n )
894+
895+ boundaries_list = []
896+ for dat in data :
897+ rmin , rmax = np .min (dat ), np .max (dat )
898+ rdelta = (rmax - rmin ) * 0.025
899+ boundaries_list .append ((rmin - rdelta , rmax + rdelta ))
900+
901+ for i , a in enumerate (data ):
902+ for j , b in enumerate (data ):
903+ ax = axes [i , j ]
904+ if i == j :
905+ ind = np .linspace (a .min (), a .max (), 1000 )
906+ ax .plot (ind , gaussian_kde (a ).evaluate (ind ))
907+ else :
908+ ax .scatter (b , a , marker = "." , alpha = 0.2 , edgecolors = "none" )
909+ ax .set_ylim (boundaries_list [i ])
910+ ax .set_xlim (boundaries_list [j ])
911+ ax .set_xlabel (label [j ])
912+ ax .set_ylabel (label [i ])
913+ if j != 0 :
914+ ax .yaxis .set_visible (False )
915+ if i != n - 1 :
916+ ax .xaxis .set_visible (False )
917+
918+ # reset labels of first kde plot to match scatter plots
919+ if n > 1 :
920+ lim1 = boundaries_list [0 ]
921+ locs = axes [0 , 1 ].yaxis .get_majorticklocs ()
922+ locs = locs [(lim1 [0 ] <= locs ) & (locs <= lim1 [1 ])]
923+ adj = (locs - lim1 [0 ]) / (lim1 [1 ] - lim1 [0 ])
924+
925+ lim0 = axes [0 , 0 ].get_ylim ()
926+ adj = adj * (lim0 [1 ] - lim0 [0 ]) + lim0 [0 ]
927+ axes [0 , 0 ].yaxis .set_ticks (adj )
928+ locs = locs .astype (int ) if np .all (locs == locs .astype (int )) else locs
929+ axes [0 , 0 ].yaxis .set_ticklabels (locs )
930+ fig .align_ylabels (axes [:, 0 ])
931+ fig .align_xlabels (axes [- 1 , :])
932+
933+ for ax in axes [- 1 , :]:
934+ plt .setp (ax .get_xticklabels (), rotation = 90 )
935+
936+ return fig
0 commit comments