Skip to content

Commit ac63d3a

Browse files
committed
auto formatting.
1 parent 206c725 commit ac63d3a

22 files changed

Lines changed: 2583 additions & 1831 deletions

examples/01_icetray/04_i3_module_in_native_icetray_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def apply_to_files(
7171
],
7272
filename=output_folder + "/" + i3_file.split("/")[-1],
7373
)
74-
74+
7575
# Called in Tray.Execute (Running the Model)
7676
tray.Execute()
7777
tray.Finish()
@@ -109,7 +109,6 @@ def main() -> None:
109109
model_name="graphnet_deployment_example",
110110
)
111111

112-
113112
# Apply module to files in IceTray
114113
apply_to_files(
115114
i3_files=input_files,

src/graphnet/data/dataconverter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def __init__(
9090
self._file_reader.set_extractors(extractors=extractors)
9191

9292
@final
93-
def __call__(self, input_file: Union[str, List[str]], folder: bool = True) -> None:
93+
def __call__(
94+
self, input_file: Union[str, List[str]], folder: bool = True
95+
) -> None:
9496
"""Extract data from files in `input_dir` and save to disk.
9597
9698
Args:
@@ -113,7 +115,7 @@ def __call__(self, input_file: Union[str, List[str]], folder: bool = True) -> No
113115
for file in input_files
114116
]
115117
else:
116-
gcd_file = '/cvmfs/icecube.opensciencegrid.org/data/GCD/GeoCalibDetectorStatus_2020.Run134142.Pass2_V0.i3.gz'
118+
gcd_file = "/cvmfs/icecube.opensciencegrid.org/data/GCD/GeoCalibDetectorStatus_2020.Run134142.Pass2_V0.i3.gz"
117119
file_list = []
118120
for file in input_file:
119121
file_list.append(I3FileSet(i3_file=file, gcd_file=gcd_file))

src/graphnet/data/extractors/icecube/i3bundleeliminator.py

Lines changed: 274 additions & 193 deletions
Large diffs are not rendered by default.

src/graphnet/data/extractors/icecube/i3bundleextractor.py

Lines changed: 189 additions & 129 deletions
Large diffs are not rendered by default.

src/graphnet/data/extractors/icecube/i3featureextractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,4 +300,4 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
300300
for truth_flag in pulses:
301301
output["truth_flag"].append(truth_flag)
302302

303-
return output
303+
return output

src/graphnet/data/extractors/icecube/i3featureextractorlegacy.py

Lines changed: 120 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,25 @@
1515
from icecube import dataclasses
1616

1717
import pandas as pd
18-
from collections import defaultdict
18+
from collections import defaultdict
19+
1920

2021
class I3FeatureExtractorLegacy(I3Extractor):
2122
"""Base class for extracting specific, reconstructed features."""
2223

23-
def __init__(self, pulsemap: str, quantiles_time: List[Any], quantiles_charge: List[Any],
24-
is_data: bool = False):
24+
def __init__(
25+
self,
26+
pulsemap: str,
27+
quantiles_time: List[Any],
28+
quantiles_charge: List[Any],
29+
is_data: bool = False,
30+
):
2531
"""Construct I3FeatureExtractorLegacy.
2632
2733
Args:
2834
pulsemap: Name of the pulse (series) map for which to extract
2935
reconstructed features.
30-
quantiles_time:
36+
quantiles_time:
3137
quantiles_charge
3238
"""
3339
# Member variable(s)
@@ -39,6 +45,7 @@ def __init__(self, pulsemap: str, quantiles_time: List[Any], quantiles_charge: L
3945
# Base class constructor
4046
super().__init__(pulsemap)
4147

48+
4249
class I3FeatureExtractorLegacyIceCube(I3FeatureExtractorLegacy):
4350
"""Class for extracting reconstructed features for IceCube-86."""
4451

@@ -54,6 +61,7 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
5461
in pure-python format.
5562
"""
5663

64+
5765
class I3FeatureExtractorLegacyIceCube86(I3FeatureExtractorLegacy):
5866
"""Class for extracting reconstructed features for IceCube-86."""
5967

@@ -124,8 +132,8 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
124132

125133
event_time = frame["I3EventHeader"].start_time.mod_julian_day_double
126134

127-
#print(frame["I3EventHeader"].event_id, frame["I3EventHeader"].sub_event_id, "graphnet run")
128-
#print(len(om_keys), 'graphnet')
135+
# print(frame["I3EventHeader"].event_id, frame["I3EventHeader"].sub_event_id, "graphnet run")
136+
# print(len(om_keys), 'graphnet')
129137

130138
if self._is_data == False:
131139
leading = self._get_leading_particle(frame=frame)
@@ -167,9 +175,9 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
167175

168176
# Loop over pulses for each OM
169177
pulses = data[om_key]
170-
#print(pulses)
178+
# print(pulses)
171179

172-
for _,pulse in enumerate(pulses):
180+
for _, pulse in enumerate(pulses):
173181

174182
output["charge"].append(
175183
getattr(pulse, "charge", padding_value)
@@ -196,15 +204,21 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
196204
output["event_time"].append(event_time)
197205

198206
if self._is_data == False:
199-
output['r'].append(
200-
phys_services.I3Calculator.closest_approach_distance(leading, self._gcd_dict[om_key].position)
207+
output["r"].append(
208+
phys_services.I3Calculator.closest_approach_distance(
209+
leading, self._gcd_dict[om_key].position
210+
)
211+
)
212+
output["residual"].append(
213+
phys_services.I3Calculator.time_residual(
214+
leading,
215+
self._gcd_dict[om_key].position,
216+
getattr(pulse, "time", padding_value),
201217
)
202-
output['residual'].append(
203-
phys_services.I3Calculator.time_residual(leading, self._gcd_dict[om_key].position, getattr(pulse, "time", padding_value))
204218
)
205219
else:
206-
output['r'].append(padding_value)
207-
output['residual'].append(padding_value)
220+
output["r"].append(padding_value)
221+
output["residual"].append(padding_value)
208222

209223
# Pulse flags
210224
flags = getattr(pulse, "flags", padding_value)
@@ -216,73 +230,108 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
216230
output["awtd"].append(self._parse_awtd_flag(pulse))
217231

218232
# Convert Event Info Into Dataframe
219-
evt_pulses = pd.DataFrame({"charge": output['charge'],
220-
"dom_time": output['dom_time'],
221-
"width": output['width'],
222-
"dom_x": output['dom_x'],
223-
"dom_y": output['dom_y'],
224-
"dom_z": output['dom_z'],
225-
"pmt_area": output['pmt_area'],
226-
"rde": output['rde'],
227-
"is_bright_dom": output['is_bright_dom'],
228-
"is_bad_dom": output['is_bad_dom'],
229-
"is_saturated_dom": output['is_saturated_dom'],
230-
"is_errata_dom": output['is_errata_dom'],
231-
"event_time": output['event_time'],
232-
"hlc": output['hlc'],
233-
"awtd": output['awtd'],
234-
"string": output['string'],
235-
"pmt_number": output['pmt_number'],
236-
"dom_number": output['dom_number'],
237-
"dom_type": output['dom_type'],
238-
"r": output['r'],
239-
"residual": output['residual'],},)
240-
233+
evt_pulses = pd.DataFrame(
234+
{
235+
"charge": output["charge"],
236+
"dom_time": output["dom_time"],
237+
"width": output["width"],
238+
"dom_x": output["dom_x"],
239+
"dom_y": output["dom_y"],
240+
"dom_z": output["dom_z"],
241+
"pmt_area": output["pmt_area"],
242+
"rde": output["rde"],
243+
"is_bright_dom": output["is_bright_dom"],
244+
"is_bad_dom": output["is_bad_dom"],
245+
"is_saturated_dom": output["is_saturated_dom"],
246+
"is_errata_dom": output["is_errata_dom"],
247+
"event_time": output["event_time"],
248+
"hlc": output["hlc"],
249+
"awtd": output["awtd"],
250+
"string": output["string"],
251+
"pmt_number": output["pmt_number"],
252+
"dom_number": output["dom_number"],
253+
"dom_type": output["dom_type"],
254+
"r": output["r"],
255+
"residual": output["residual"],
256+
},
257+
)
258+
241259
# Produce Quantile Information of Each DOM
242-
t_quantiles = evt_pulses.groupby(["string", "dom_number"])['dom_time'].quantile(self._quantiles_time).unstack().reset_index()
260+
t_quantiles = (
261+
evt_pulses.groupby(["string", "dom_number"])["dom_time"]
262+
.quantile(self._quantiles_time)
263+
.unstack()
264+
.reset_index()
265+
)
243266
for quant in self._quantiles_time:
244-
t_quantiles = t_quantiles.rename(columns={quant: f't{int(1000*quant)}'})
267+
t_quantiles = t_quantiles.rename(
268+
columns={quant: f"t{int(1000*quant)}"}
269+
)
245270

246-
evt_pulses['qcumsum'] = evt_pulses.groupby(["string", "dom_number"])['charge'].cumsum()
247-
q_quantiles = evt_pulses.groupby(["string", "dom_number"])['qcumsum'].quantile(self._quantiles_charge).unstack().reset_index()
248-
evt_pulses.drop(columns=['qcumsum'], inplace=True)
271+
evt_pulses["qcumsum"] = evt_pulses.groupby(["string", "dom_number"])[
272+
"charge"
273+
].cumsum()
274+
q_quantiles = (
275+
evt_pulses.groupby(["string", "dom_number"])["qcumsum"]
276+
.quantile(self._quantiles_charge)
277+
.unstack()
278+
.reset_index()
279+
)
280+
evt_pulses.drop(columns=["qcumsum"], inplace=True)
249281
for quant in self._quantiles_charge:
250-
q_quantiles = q_quantiles.rename(columns={quant: f'q{int(1000*quant)}'})
251-
q_total = evt_pulses.groupby(["string", "dom_number"], as_index=False)['charge'].sum()
282+
q_quantiles = q_quantiles.rename(
283+
columns={quant: f"q{int(1000*quant)}"}
284+
)
285+
q_total = evt_pulses.groupby(["string", "dom_number"], as_index=False)[
286+
"charge"
287+
].sum()
252288
# Extrac the Minimum Pulse Time of Each Dom
253-
min_times = evt_pulses.loc[evt_pulses.groupby(["string", "dom_number"], as_index=True)['dom_time'].idxmin()]
289+
min_times = evt_pulses.loc[
290+
evt_pulses.groupby(["string", "dom_number"], as_index=True)[
291+
"dom_time"
292+
].idxmin()
293+
]
254294

255-
min_times = min_times.merge(t_quantiles, on = ["string", "dom_number"])
256-
min_times = min_times.merge(q_quantiles, on = ["string", "dom_number"])
295+
min_times = min_times.merge(t_quantiles, on=["string", "dom_number"])
296+
min_times = min_times.merge(q_quantiles, on=["string", "dom_number"])
257297

258-
min_times['adjusted_time'] = min_times["dom_time"] - min_times["dom_time"].min()
259-
260-
total_pulses = evt_pulses.groupby(["string", "dom_number"], as_index=False)['charge'].size()
298+
min_times["adjusted_time"] = (
299+
min_times["dom_time"] - min_times["dom_time"].min()
300+
)
261301

262-
min_times['dom_qtot'] = q_total['charge']
263-
min_times['dom_qtot_exc'] = q_total['charge']
264-
min_times['total_pulses'] = total_pulses['size']
302+
total_pulses = evt_pulses.groupby(
303+
["string", "dom_number"], as_index=False
304+
)["charge"].size()
265305

266-
bright_doms = min_times['dom_qtot']/frame['Homogenized_QTot_New'].value >= .4
306+
min_times["dom_qtot"] = q_total["charge"]
307+
min_times["dom_qtot_exc"] = q_total["charge"]
308+
min_times["total_pulses"] = total_pulses["size"]
267309

268-
min_times['bright_dom'] = bright_doms.to_numpy(dtype=float)
310+
bright_doms = (
311+
min_times["dom_qtot"] / frame["Homogenized_QTot_New"].value >= 0.4
312+
)
269313

270-
bad_doms = (min_times['is_errata_dom'] == 1) | (min_times['is_saturated_dom'] == 1)
271-
t_name_keys = [f't{int(1000*quant)}' for quant in self._quantiles_time]
272-
q_name_keys = [f'q{int(1000*quant)}' for quant in self._quantiles_charge]
314+
min_times["bright_dom"] = bright_doms.to_numpy(dtype=float)
315+
316+
bad_doms = (min_times["is_errata_dom"] == 1) | (
317+
min_times["is_saturated_dom"] == 1
318+
)
319+
t_name_keys = [f"t{int(1000*quant)}" for quant in self._quantiles_time]
320+
q_name_keys = [
321+
f"q{int(1000*quant)}" for quant in self._quantiles_charge
322+
]
273323

274324
for t_name in t_name_keys:
275325
min_times[t_name] = min_times[t_name] - min_times["dom_time"].min()
276326

277327
# Remove This
278-
#min_times.loc[bad_doms, t_name_keys] = -100
279-
#min_times.loc[bad_doms, q_name_keys] = -100
280-
min_times.loc[bad_doms, 'dom_qtot_exc'] = -100
281-
328+
# min_times.loc[bad_doms, t_name_keys] = -100
329+
# min_times.loc[bad_doms, q_name_keys] = -100
330+
min_times.loc[bad_doms, "dom_qtot_exc"] = -100
282331

283-
#print(min_times)
284-
output = min_times.to_dict(orient='list')
285-
#print(min_times)
332+
# print(min_times)
333+
output = min_times.to_dict(orient="list")
334+
# print(min_times)
286335
return output
287336

288337
def _get_relative_dom_efficiency(
@@ -323,11 +372,12 @@ def _parse_awtd_flag(
323372

324373
# Error Getting this for a certain set
325374
def _get_leading_particle(
326-
self, frame: "icetray.I3Frame",
327-
):
375+
self,
376+
frame: "icetray.I3Frame",
377+
):
328378

329379
try:
330-
tracklist = frame['MMCTrackList']
380+
tracklist = frame["MMCTrackList"]
331381

332382
max_energy = -1
333383
max_particle = tracklist[0]
@@ -338,6 +388,6 @@ def _get_leading_particle(
338388

339389
return max_particle.particle
340390
except:
341-
print('no mmctracklist')
342-
mctree = frame['I3MCTree_preMuonProp']
343-
return mctree[1]
391+
print("no mmctracklist")
392+
mctree = frame["I3MCTree_preMuonProp"]
393+
return mctree[1]

0 commit comments

Comments
 (0)