Skip to content

Commit 38b3782

Browse files
authored
Improve plot script (#13)
* refactor
* clip to 5 decimal places
1 parent 020840b commit 38b3782

2 files changed

Lines changed: 115 additions & 89 deletions

File tree

package.json

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
"three": "^0.130.1",
5353
"typescript": "^5.8.3",
5454
"typescript-transform-paths": "^3.5.5",
55+
"http-server": "^14.1.1",
5556
"webpack": "^5.99.9"
5657
}
5758
}

scripts/plotting_script.py

Lines changed: 114 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

33
import argparse
4+
from math import sqrt
45
from collections import defaultdict
56
from datetime import datetime
67
from glob import glob
78

89
import matplotlib.pyplot as plt
9-
import mplhep as hep # type: ignore
10+
import mplhep as hep
1011

1112

1213
def parser_setup() -> argparse.ArgumentParser:
@@ -17,23 +18,15 @@ def parser_setup() -> argparse.ArgumentParser:
1718
"--channel", "-c", type=str, nargs="*", choices=["Higgs", "W", "Z", "all"], default=None,
1819
help="Channel(s) to plot; defaults to Higgs and Z channels.",
1920
)
21+
parser.add_argument("--skip-w-ratio", "-swr", action='store_true', default=False, help="Calculate W+ to W- ratio; default is False")
2022
parser.add_argument("--min", "-m", type=float, default=10.0, help="minimum value for the histogram; default is 10.0")
2123
parser.add_argument("--unstack", "-u", action='store_true', default=False, help="Unstack the histograms; default is stacked")
2224
parser.add_argument("--transverse-mass", "-t", action='store_true', default=False, help="Plot Transverse Mass; default is Invariant Mass")
2325
parser.add_argument("--n-bins", "-n", type=int, default=20, help="Number of Bins; default is 20")
2426
return parser
2527

2628

27-
def get_files_by_channel(channel: str | list[str] | None, input_folder: str) -> dict[str, list[str]]:
28-
if not channel:
29-
channels = ["Higgs", "Z"]
30-
elif channel == "all":
31-
channels = ["Higgs", "W", "Z"]
32-
else:
33-
channels = channel if isinstance(channel, list) else [channel]
34-
if "W" in channels:
35-
channels.extend(["Wp", "Wm"])
36-
channels.remove("W")
29+
def get_files_by_channel(channels: list[str], input_folder: str) -> dict[str, list[str]]:
3730
files = {}
3831
for ch in channels:
3932
ch_files = glob(f"{input_folder}/{ch}*.csv")
@@ -44,37 +37,37 @@ def get_files_by_channel(channel: str | list[str] | None, input_folder: str) ->
4437
return files
4538

4639

47-
class MassReader:
48-
def __init__(self, files: dict[str, list[str]], n_bins: int, x_min: float, transverse_mass: bool = False):
49-
""" Reads the mass data from the files and stores it in a dictionary.
50-
Args:
51-
files (dict[str, list[str]]): A dictionary where the keys are the channel names and the values are lists of
52-
csv file paths containing the mass data.
53-
n_bins (int): The number of bins to use for the histogram.
54-
x_min (float): The minimum value for the histogram.
55-
transverse_mass (bool): Whether to read the transverse mass or the invariant mass. Default is False.
40+
def resolve_channels(channel: str | list | None) -> list[str]:
    """ Translates the user's channel selection into the channel names used as file prefixes.

    Args:
        channel (str | list | None): A single channel name, a list of channel names, or None/empty. The special
            name "all" selects Higgs, W and Z; it is recognised both as a bare string and as a list entry,
            because argparse delivers `nargs="*"` values as a list.

    Returns:
        list[str]: The resolved channel names without duplicates. "W" is replaced by "Wp" and "Wm", which are
            placed at the end of the list. The input list is never modified.
    """
    if not channel:
        return ["Higgs", "Z"]

    # Work on a copy so the caller's list (e.g. args.channel) is left untouched
    requested = [channel] if isinstance(channel, str) else list(channel)
    if "all" in requested:
        requested = ["Higgs", "W", "Z"]

    if "W" in requested:
        # W is stored as two separate samples; drop every "W" entry and append both charges
        requested = [name for name in requested if name != "W"] + ["Wp", "Wm"]

    # dict.fromkeys removes duplicates while keeping the first-seen order
    return list(dict.fromkeys(requested))
5651

57-
Raises:
58-
ValueError: If the file format is invalid or if no data is found in the file.
59-
"""
60-
# instance variables
61-
self.data: dict[str, list[float]] = defaultdict(list)
62-
self.transverse_mass: bool = transverse_mass
6352

53+
class BaseReader:
54+
def __init__(self, files: dict[str, list[str]], column: str = "Invariant Mass"):
55+
if not files:
56+
error_msg = "No files found for the chosen channels."
57+
if self.__class__ == WReader:
58+
error_msg += "\nCannot calculate W+ to W- ratio."
59+
raise Exception(error_msg)
60+
self.data: dict[str, list[float]] = defaultdict(list)
6461
# Read the data from the files
6562
for channel, file_list in files.items():
63+
read_files = []
6664
if not file_list:
6765
continue
6866
for file in file_list:
69-
self.data[channel].append(self._read_file(file))
70-
self.data[channel] = sum(self.data[channel], [])
67+
read_files.append(self._read_file(file, column))
68+
self.data[channel] = sum(read_files, [])
7169

72-
# Define the histogram bins
73-
minimum = max(x_min, min(min(data_array) for data_array in self.data.values()))
74-
maximum = max(max(data_array) for data_array in self.data.values())
75-
self.bins: list[float] = [minimum + i * (maximum - minimum) / n_bins for i in range(n_bins + 1)]
76-
77-
def _read_file(self, file: str) -> list[float]:
70+
def _read_file(self, file: str, mass_label: str) -> list[float]:
7871
# check file path validity
7972
if not file.endswith(".csv"):
8073
raise ValueError(f"Invalid file format: '{file}'")
@@ -83,7 +76,6 @@ def _read_file(self, file: str) -> list[float]:
8376
with open(file, 'r') as f:
8477
lines = f.readlines()
8578
header = lines[0].strip().split(',')
86-
mass_label = "Transverse Mass" if self.transverse_mass else "Invariant Mass"
8779
mass_index = header.index(mass_label)
8880
data = [float(line.strip().split(',')[mass_index]) for line in lines[1:]]
8981

@@ -92,23 +84,6 @@ def _read_file(self, file: str) -> list[float]:
9284
print(f"Skipping '{file}'... No data found in file.")
9385
return data
9486

95-
def w_ratio(self):
96-
""" Calculates the ratio of W+ to W- events.
97-
"""
98-
wp = len(self.data["Wp"])
99-
wm = len(self.data["Wm"])
100-
101-
print(40 * "*")
102-
print("***\tCalculating W+ to W- ratio...")
103-
if wm == 0:
104-
print("!!!\tNo W- events found. Skipping W+ to W- ratio calculation.")
105-
else:
106-
w_ratio = wp / wm
107-
w_error = np.sqrt(w_ratio / wm * (1 + w_ratio))
108-
print(f"***\tFound: {wp} W+ events and {wm} W- events.")
109-
print(f"***\tW+ to W- ratio: {w_ratio} ± {w_error} (stat)")
110-
print(40 * "*")
111-
11287
def items(self) -> tuple[list[str], list[list[float]]]:
11388
""" Returns the keys and values of the data dictionary as separate lists. The keys are modified to match the
11489
expected channel names.
@@ -122,53 +97,103 @@ def items(self) -> tuple[list[str], list[list[float]]]:
12297
return keys, list(self.data.values())
12398

12499

125-
def plot_masses(reader: MassReader, unstack: bool, output: str, **kwargs):
126-
""" Plots the masses from the MassReader object.
127-
Args:
128-
reader (MassReader): The MassReader object containing the mass data.
129-
unstack (bool): Whether to unstack the histograms.
130-
output (str): The output file path.
131-
**kwargs: Additional keyword arguments.
100+
class MassReader(BaseReader):
    """ Reads the mass column of every channel and plots the result as a histogram.
    """

    def __init__(self, files: dict[str, list[str]], n_bins: int, x_min: float, transverse_mass: bool = False):
        """ Reads the mass data from the files and defines the histogram bin edges.

        Args:
            files (dict[str, list[str]]): A dictionary where the keys are the channel names and the values are
                lists of csv file paths containing the mass data.
            n_bins (int): The number of bins to use for the histogram. Must be at least 1.
            x_min (float): The minimum value for the histogram.
            transverse_mass (bool): Whether to read the transverse mass or the invariant mass. Default is False.

        Raises:
            ValueError: If n_bins is smaller than 1, if a file does not have the csv format, or if none of the
                channels contains any mass value.
            Exception: If `files` is empty (raised by BaseReader).
        """
        if n_bins < 1:
            raise ValueError(f"Number of bins must be at least 1, got {n_bins}.")

        # instance variables
        self.transverse_mass: bool = transverse_mass
        super().__init__(files, "Transverse Mass" if transverse_mass else "Invariant Mass")

        # Channels whose files contained no rows are stored as empty lists; leave them out of the range
        # computation so min()/max() do not fail on an empty sequence
        populated = [masses for masses in self.data.values() if masses]
        if not populated:
            raise ValueError("No mass values found in the given files; cannot define the histogram bins.")

        # Define the histogram bins
        minimum = max(x_min, min(min(masses) for masses in populated))
        maximum = max(max(masses) for masses in populated)
        bin_width = (maximum - minimum) / n_bins
        self.bins: list[float] = [minimum + i * bin_width for i in range(n_bins + 1)]

    def plot_masses(self, unstack: bool, output: str, **kwargs):
        """ Plots the masses stored in this reader, saves the figure and shows it.

        Args:
            unstack (bool): Whether to unstack the histograms.
            output (str): The output file path.
            **kwargs: Additional keyword arguments. They are ignored; accepting them allows the complete set of
                parsed command line arguments to be passed in.
        """
        # Set the style
        plt.style.use(hep.style.CMS)
        # initialize the figure and setup axis
        _, ax = plt.subplots(dpi=100, figsize=(8, 8))
        hep.cms.label(data=False, rlabel="Masterclass", ax=ax)
        ax.set_xlabel("Transverse Mass [GeV]" if self.transverse_mass else "Invariant Mass [GeV]")
        ax.set_ylabel("Events")
        # create a config dict
        hist_kwargs = {
            "bins": self.bins,
            "stacked": not unstack,
            "histtype": "step" if unstack else "stepfilled",
        }
        # plot the data
        channels, masses = self.items()
        ax.hist(masses, **hist_kwargs, label=channels)
        # add legend and grid
        ax.legend()
        ax.grid()
        # save the figure
        plt.savefig(output, dpi=300, bbox_inches="tight")
        print(f"Saved plot to '{output}'")
        # show the plot
        plt.show()
154+
155+
156+
class WReader(BaseReader):
157+
""" Calculates the ratio of W+ to W- events.
132158
"""
133-
# Set the style
134-
plt.style.use(hep.style.CMS)
135-
# initialize the figure and setup axis
136-
_, ax = plt.subplots(dpi=100, figsize=(8, 8))
137-
hep.cms.label(data=False, rlabel="Masterclass", ax=ax)
138-
ax.set_xlabel("Transverse Mass [GeV]" if reader.transverse_mass else "Invariant Mass [GeV]")
139-
ax.set_ylabel("Events")
140-
# create a config dict
141-
hist_kwargs = {
142-
"bins": reader.bins,
143-
"stacked": not unstack,
144-
"histtype": "stepfilled" if not unstack else "step",
145-
}
146-
# plot the data
147-
channels, masses = reader.items()
148-
ax.hist(masses, **hist_kwargs, label=channels)
149-
# add legend and grid
150-
ax.legend()
151-
ax.grid()
152-
# save the figure
153-
plt.savefig(output, dpi=300, bbox_inches="tight")
154-
print(f"Saved plot to '{output}'")
155-
# show the plot
156-
plt.show()
159+
def __init__(self, files: dict[str, list[str]]):
    """ Reads the W+ and W- events and computes their ratio with its statistical error.

    Args:
        files (dict[str, list[str]]): A dictionary where the keys are the channel names ("Wp", "Wm") and the
            values are lists of csv file paths.

    Raises:
        Exception: If `files` is empty (raised by BaseReader).
    """
    super(WReader, self).__init__(files)

    self.wp = len(self.data["Wp"])
    self.wm = len(self.data["Wm"])

    # Always define both attributes so the instance is fully initialised; NaN marks an undefined ratio
    self.w_ratio: float = float("nan")
    self.w_error: float = float("nan")

    if self.wm == 0:
        print("!!!\tNo W- events found. Skipping W+ to W- ratio calculation.")
    else:
        self.w_ratio = self.wp / self.wm
        self.w_error = sqrt(self.w_ratio / self.wm * (1 + self.w_ratio))
171+
172+
def print_ratio(self):
    """ Prints the number of W+ and W- events and, if it is defined, their ratio with its statistical error.

    The ratio line is replaced by a notice when there are no W- events, because the ratio is only computed
    for a non-zero W- count.
    """
    print(40 * "*")
    print("***\tCalculating W+ to W- ratio...")
    print(f"***\tFound: {self.wp} W+ events and {self.wm} W- events.")
    if self.wm == 0:
        # Division by zero events: there is no ratio to report
        print("***\tW+ to W- ratio: undefined (no W- events)")
    else:
        print(f"***\tW+ to W- ratio: {self.w_ratio:.5f} ± {self.w_error:.5f} (stat)")
    print(40 * "*")
157178

158179

159180
def main():
160181
parser = parser_setup()
161182
args = parser.parse_args()
162183
if args.output is None:
163184
args.output = f"./{datetime.now().strftime('%d%m%Y')}.png"
164-
files = get_files_by_channel(args.channel, args.input)
165-
if not files:
166-
raise Exception(f"No files found for channel '{args.channel}' in folder '{args.input}'")
185+
# Handle channels
186+
channels = resolve_channels(args.channel)
187+
188+
if not args.skip_w_ratio:
189+
w_files = get_files_by_channel(["Wp", "Wm"], args.input)
190+
w_reader = WReader(files=w_files)
191+
w_reader.print_ratio()
192+
input("Press Enter to continue to the plot...")
167193

194+
files = get_files_by_channel(channels, args.input)
168195
reader = MassReader(files=files, n_bins=args.n_bins, x_min=args.min, transverse_mass=args.transverse_mass)
169-
reader.w_ratio()
170-
input("Press Enter to continue to the plot...")
171-
plot_masses(reader, **args.__dict__)
196+
reader.plot_masses(**args.__dict__)
172197

173198

174199
if __name__ == "__main__":

0 commit comments

Comments
 (0)