11from __future__ import annotations
22
33import argparse
4+ from math import sqrt
45from collections import defaultdict
56from datetime import datetime
67from glob import glob
78
89import matplotlib .pyplot as plt
9- import mplhep as hep # type: ignore
10+ import mplhep as hep
1011
1112
1213def parser_setup () -> argparse .ArgumentParser :
@@ -17,23 +18,15 @@ def parser_setup() -> argparse.ArgumentParser:
1718 "--channel" , "-c" , type = str , nargs = "*" , choices = ["Higgs" , "W" , "Z" , "all" ], default = None ,
1819 help = "Channel(s) to plot; defaults to Higgs and Z channels." ,
1920 )
21+ parser .add_argument ("--skip-w-ratio" , "-swr" , action = 'store_true' , default = False , help = "Calculate W+ to W- ratio; default is False" )
2022 parser .add_argument ("--min" , "-m" , type = float , default = 10.0 , help = "minimum value for the histogram; default is 10.0" )
2123 parser .add_argument ("--unstack" , "-u" , action = 'store_true' , default = False , help = "Unstack the histograms; default is stacked" )
2224 parser .add_argument ("--transverse-mass" , "-t" , action = 'store_true' , default = False , help = "Plot Transverse Mass; default is Invariant Mass" )
2325 parser .add_argument ("--n-bins" , "-n" , type = int , default = 20 , help = "Number of Bins; default is 20" )
2426 return parser
2527
2628
27- def get_files_by_channel (channel : str | list [str ] | None , input_folder : str ) -> dict [str , list [str ]]:
28- if not channel :
29- channels = ["Higgs" , "Z" ]
30- elif channel == "all" :
31- channels = ["Higgs" , "W" , "Z" ]
32- else :
33- channels = channel if isinstance (channel , list ) else [channel ]
34- if "W" in channels :
35- channels .extend (["Wp" , "Wm" ])
36- channels .remove ("W" )
29+ def get_files_by_channel (channels : list [str ], input_folder : str ) -> dict [str , list [str ]]:
3730 files = {}
3831 for ch in channels :
3932 ch_files = glob (f"{ input_folder } /{ ch } *.csv" )
@@ -44,37 +37,37 @@ def get_files_by_channel(channel: str | list[str] | None, input_folder: str) ->
4437 return files
4538
4639
47- class MassReader :
48- def __init__ (self , files : dict [str , list [str ]], n_bins : int , x_min : float , transverse_mass : bool = False ):
49- """ Reads the mass data from the files and stores it in a dictionary.
50- Args:
51- files (dict[str, list[str]]): A dictionary where the keys are the channel names and the values are lists of
52- csv file paths containing the mass data.
53- n_bins (int): The number of bins to use for the histogram.
54- x_min (float): The minimum value for the histogram.
55- transverse_mass (bool): Whether to read the transverse mass or the invariant mass. Default is False.
def resolve_channels(channel: str | list | None) -> list[str]:
    """Expand the ``--channel`` selection into the file-name prefixes to load.

    Args:
        channel: ``None`` or an empty list for the default selection, a single
            channel name, or a list of channel names. The command line always
            delivers a list because the option is declared with ``nargs="*"``.

    Returns:
        A new list of unique prefixes. ``"all"`` (alone or inside a list)
        selects Higgs, W and Z. ``"W"`` is replaced by ``"Wp"`` and ``"Wm"``,
        which are placed after the other channels. The argument itself is
        never modified.
    """
    if not channel:
        return ["Higgs", "Z"]

    # Work on a copy so the caller's list (e.g. ``args.channel``) stays intact.
    requested = [channel] if isinstance(channel, str) else list(channel)

    # argparse hands over ["all"], not "all", so the list form must be checked too.
    if "all" in requested:
        requested = ["Higgs", "W", "Z"]

    resolved = [name for name in requested if name != "W"]
    if "W" in requested:
        # A leftover "W" prefix would glob both Wp* and Wm* files a second time,
        # so every occurrence is dropped and replaced by the two charge states.
        resolved.extend(["Wp", "Wm"])

    # Order-preserving de-duplication keeps a channel from being read twice.
    return list(dict.fromkeys(resolved))
5651
57- Raises:
58- ValueError: If the file format is invalid or if no data is found in the file.
59- """
60- # instance variables
61- self .data : dict [str , list [float ]] = defaultdict (list )
62- self .transverse_mass : bool = transverse_mass
6352
53+ class BaseReader :
54+ def __init__ (self , files : dict [str , list [str ]], column : str = "Invariant Mass" ):
55+ if not files :
56+ error_msg = "No files found for the chosen channels."
57+ if self .__class__ == WReader :
58+ error_msg += "\n Cannot calculate W+ to W- ratio."
59+ raise Exception (error_msg )
60+ self .data : dict [str , list [float ]] = defaultdict (list )
6461 # Read the data from the files
6562 for channel , file_list in files .items ():
63+ read_files = []
6664 if not file_list :
6765 continue
6866 for file in file_list :
69- self . data [ channel ]. append (self ._read_file (file ))
70- self .data [channel ] = sum (self . data [ channel ] , [])
67+ read_files . append (self ._read_file (file , column ))
68+ self .data [channel ] = sum (read_files , [])
7169
72- # Define the histogram bins
73- minimum = max (x_min , min (min (data_array ) for data_array in self .data .values ()))
74- maximum = max (max (data_array ) for data_array in self .data .values ())
75- self .bins : list [float ] = [minimum + i * (maximum - minimum ) / n_bins for i in range (n_bins + 1 )]
76-
77- def _read_file (self , file : str ) -> list [float ]:
70+ def _read_file (self , file : str , mass_label : str ) -> list [float ]:
7871 # check file path validity
7972 if not file .endswith (".csv" ):
8073 raise ValueError (f"Invalid file format: '{ file } '" )
@@ -83,7 +76,6 @@ def _read_file(self, file: str) -> list[float]:
8376 with open (file , 'r' ) as f :
8477 lines = f .readlines ()
8578 header = lines [0 ].strip ().split (',' )
86- mass_label = "Transverse Mass" if self .transverse_mass else "Invariant Mass"
8779 mass_index = header .index (mass_label )
8880 data = [float (line .strip ().split (',' )[mass_index ]) for line in lines [1 :]]
8981
@@ -92,23 +84,6 @@ def _read_file(self, file: str) -> list[float]:
9284 print (f"Skipping '{ file } '... No data found in file." )
9385 return data
9486
95- def w_ratio (self ):
96- """ Calculates the ratio of W+ to W- events.
97- """
98- wp = len (self .data ["Wp" ])
99- wm = len (self .data ["Wm" ])
100-
101- print (40 * "*" )
102- print ("***\t Calculating W+ to W- ratio..." )
103- if wm == 0 :
104- print ("!!!\t No W- events found. Skipping W+ to W- ratio calculation." )
105- else :
106- w_ratio = wp / wm
107- w_error = np .sqrt (w_ratio / wm * (1 + w_ratio ))
108- print (f"***\t Found: { wp } W+ events and { wm } W- events." )
109- print (f"***\t W+ to W- ratio: { w_ratio } ± { w_error } (stat)" )
110- print (40 * "*" )
111-
11287 def items (self ) -> tuple [list [str ], list [list [float ]]]:
11388 """ Returns the keys and values of the data dictionary as separate lists. The keys are modified to match the
11489 expected channel names.
@@ -122,53 +97,103 @@ def items(self) -> tuple[list[str], list[list[float]]]:
12297 return keys , list (self .data .values ())
12398
12499
125- def plot_masses (reader : MassReader , unstack : bool , output : str , ** kwargs ):
126- """ Plots the masses from the MassReader object.
127- Args:
128- reader (MassReader): The MassReader object containing the mass data.
129- unstack (bool): Whether to unstack the histograms.
130- output (str): The output file path.
131- **kwargs: Additional keyword arguments.
class MassReader(BaseReader):
    """Mass values per channel together with the histogram binning used to plot them."""

    def __init__(self, files: dict[str, list[str]], n_bins: int, x_min: float, transverse_mass: bool = False):
        """ Reads the mass data from the files and defines the histogram bins.
        Args:
            files (dict[str, list[str]]): A dictionary where the keys are the channel names and the values are lists of
                csv file paths containing the mass data.
            n_bins (int): The number of bins to use for the histogram; must be at least 1.
            x_min (float): The minimum value for the histogram.
            transverse_mass (bool): Whether to read the transverse mass or the invariant mass. Default is False.

        Raises:
            ValueError: If ``n_bins`` is smaller than 1, if none of the channels contains any mass value, or if
                ``x_min`` lies above the largest mass value.
        """
        if n_bins < 1:
            # Fail early with a clear message instead of a ZeroDivisionError in the bin formula.
            raise ValueError(f"n_bins must be at least 1, got {n_bins}")

        # instance variables
        self.transverse_mass: bool = transverse_mass
        super().__init__(files, "Transverse Mass" if transverse_mass else "Invariant Mass")

        # Channels without any value are left out of the range computation; min()/max() of an empty
        # list would otherwise abort with an unhelpful "arg is an empty sequence" error.
        populated = [values for values in self.data.values() if values]
        if not populated:
            raise ValueError("No mass values found in the input files; cannot define the histogram bins.")

        # Define the histogram bins
        minimum = max(x_min, min(min(values) for values in populated))
        maximum = max(max(values) for values in populated)
        if minimum > maximum:
            raise ValueError(f"Histogram minimum ({minimum}) is above the largest mass value ({maximum}).")
        self.bins: list[float] = [minimum + i * (maximum - minimum) / n_bins for i in range(n_bins + 1)]

    def plot_masses(self, unstack: bool, output: str, **kwargs):
        """ Plots the stored masses as a histogram, saves the figure and shows it.
        Args:
            unstack (bool): Whether to unstack the histograms.
            output (str): The output file path.
            **kwargs: Ignored. Accepted so that the complete set of parsed command line options can be passed in.
        """
        # Set the style
        plt.style.use(hep.style.CMS)
        # initialize the figure and setup axis
        _, ax = plt.subplots(dpi=100, figsize=(8, 8))
        hep.cms.label(data=False, rlabel="Masterclass", ax=ax)
        ax.set_xlabel("Transverse Mass [GeV]" if self.transverse_mass else "Invariant Mass [GeV]")
        ax.set_ylabel("Events")
        # create a config dict
        hist_kwargs = {
            "bins": self.bins,
            "stacked": not unstack,
            "histtype": "stepfilled" if not unstack else "step",
        }
        # plot the data
        channels, masses = self.items()
        ax.hist(masses, **hist_kwargs, label=channels)
        # add legend and grid
        ax.legend()
        ax.grid()
        # save the figure
        plt.savefig(output, dpi=300, bbox_inches="tight")
        print(f"Saved plot to '{output}'")
        # show the plot
        plt.show()
154+
155+
class WReader(BaseReader):
    """ Calculates the ratio of W+ to W- events.

    The event counts are the number of values read for the "Wp" and "Wm" channels. The ratio and its
    statistical uncertainty are only defined when at least one W- event was read; otherwise both are None.
    """

    def __init__(self, files: dict[str, list[str]]):
        """ Reads the W files and computes the W+ to W- ratio.
        Args:
            files (dict[str, list[str]]): A dictionary where the keys are the channel names and the values are lists of
                csv file paths.
        """
        super().__init__(files)

        # self.data is a defaultdict, so a channel without files simply counts as zero events.
        self.wp: int = len(self.data["Wp"])
        self.wm: int = len(self.data["Wm"])

        # Always define both attributes so print_ratio() never hits a missing attribute.
        self.w_ratio: float | None = None
        self.w_error: float | None = None

        if self.wm == 0:
            print("!!!\t No W- events found. Skipping W+ to W- ratio calculation.")
        else:
            self.w_ratio = self.wp / self.wm
            self.w_error = sqrt(self.w_ratio / self.wm * (1 + self.w_ratio))

    def print_ratio(self):
        """ Prints the event counts and, when it could be computed, the W+ to W- ratio."""
        print(40 * "*")
        print("***\t Calculating W+ to W- ratio...")
        print(f"***\t Found: {self.wp} W+ events and {self.wm} W- events.")
        if self.w_ratio is None or self.w_error is None:
            # No W- events: the ratio is undefined, so report that instead of failing.
            print("***\t W+ to W- ratio: not available (no W- events).")
        else:
            print(f"***\t W+ to W- ratio: {self.w_ratio:.5f} ± {self.w_error:.5f} (stat)")
        print(40 * "*")
157178
158179
def main():
    """Command line entry point: optionally report the W+/W- ratio, then plot the mass histogram."""
    options = parser_setup().parse_args()

    # Fall back to a date-stamped file in the working directory when no output path is given.
    if options.output is None:
        stamp = datetime.now().strftime('%d%m%Y')
        options.output = f"./{stamp}.png"

    selected = resolve_channels(options.channel)

    # The ratio step runs unless it was explicitly skipped on the command line.
    if not options.skip_w_ratio:
        ratio_reader = WReader(files=get_files_by_channel(["Wp", "Wm"], options.input))
        ratio_reader.print_ratio()
        input("Press Enter to continue to the plot...")

    mass_reader = MassReader(
        files=get_files_by_channel(selected, options.input),
        n_bins=options.n_bins,
        x_min=options.min,
        transverse_mass=options.transverse_mass,
    )
    # Every parsed option is forwarded; plot_masses picks the ones it needs.
    mass_reader.plot_masses(**vars(options))
172197
173198
174199if __name__ == "__main__" :
0 commit comments