-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtruth_match.py
More file actions
101 lines (87 loc) · 3.09 KB
/
truth_match.py
File metadata and controls
101 lines (87 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
"""Standalone truth matching."""
import argparse
import logging
from collections import Counter
from tqdm import tqdm
import ROOT
def find_MC_track(track, event, threshold=0.7):
    """Match a reconstructed track to a single MC track.

    Every hit of ``track`` is mapped through the first entry of the event's
    ``Digi_TargetClusterHits2MCPoints`` link collection to the MC points it
    was built from, and from there to their MC track IDs. Each hit votes
    once for every distinct MC track ID among its MC points. A hit built
    from several MC points of the same MC track therefore counts as one
    hit, matching the "fraction of hits" criterion.

    Args:
        track: reconstructed track exposing ``getPoints()``; each point
            exposes ``getRawMeasurement().getDetId()``.
        event: event entry exposing ``Digi_TargetClusterHits2MCPoints`` and
            ``AdvTargetPoint``.
        threshold: minimum fraction of the track's hits that must be linked
            to one MC track for a match (default 0.7).

    Returns:
        The matched MC track ID, or -1 if no MC track reaches the threshold
        (including tracks without hits).
    """
    link = event.Digi_TargetClusterHits2MCPoints[0]
    points = track.getPoints()
    votes = Counter()
    for p in points:
        # wList yields pairs whose first element indexes AdvTargetPoint.
        # dict.fromkeys keeps the distinct MC track IDs in first-seen order.
        hit_track_ids = dict.fromkeys(
            event.AdvTargetPoint[index].GetTrackID()
            for index, _ in link.wList(p.getRawMeasurement().getDetId())
        )
        # MC track ID -2 is excluded from the vote (presumably "not stored";
        # TODO confirm against the simulation conventions).
        hit_track_ids.pop(-2, None)
        # One vote per MC track per hit, so counts are numbers of hits.
        votes.update(hit_track_ids.keys())
    if not votes:
        return -1
    most_common_track, count = votes.most_common(1)[0]
    # Truth match if >= threshold (default 70 %) of the hits are related to
    # a single MC track.
    if count >= len(points) * threshold:
        return most_common_track
    return -1
def match_vertex(vertex, event, threshold=0.7):
    """Match a reconstructed vertex to the start of its matched MC tracks.

    Only vertex tracks that are already truth matched (``getMcTrackId() >= 0``)
    take part, and at least two are required. Their MC tracks are looked up
    in ``event.MCTrack``; IDs beyond its length are skipped. The MC tracks
    are then grouped by mother ID. If the most common mother accounts for at
    least ``threshold`` of the matched tracks, the function returns the start
    vertex of the first MC track with that mother.

    Args:
        vertex: reconstructed vertex exposing ``getNTracks()`` and
            ``getParameters(i).getTrack()``.
        event: event entry exposing the ``MCTrack`` collection.
        threshold: minimum fraction of matched tracks that must share one
            mother (default 0.7).

    Returns:
        ``ROOT.TVector3`` filled via ``GetStartVertex``, or None if there is
        no match.
    """
    mc_ids = [
        vertex.getParameters(i).getTrack().getMcTrackId()
        for i in range(vertex.getNTracks())
    ]
    matched_ids = [mc_id for mc_id in mc_ids if mc_id >= 0]
    if len(matched_ids) < 2:
        return None
    n_mc_tracks = len(event.MCTrack)
    mc_tracks = [event.MCTrack[mc_id] for mc_id in matched_ids if mc_id < n_mc_tracks]
    mother_ids = [mc_track.GetMotherId() for mc_track in mc_tracks]
    if not mother_ids:
        return None
    most_common_mother, count = Counter(mother_ids).most_common(1)[0]
    # Truth match if >= threshold (default 70 %) of the matched tracks share
    # a single mother. The denominator is all matched tracks, including any
    # whose MC ID was out of range.
    if count < len(matched_ids) * threshold:
        return None
    mc_track = mc_tracks[mother_ids.index(most_common_mother)]
    true_vertex = ROOT.TVector3()
    mc_track.GetStartVertex(true_vertex)
    return true_vertex
def main():
    """Truth match reconstructed tracks and write them to a new file.

    Every entry of the input ``cbmsim`` tree is copied to the output file.
    Before each entry is filled, every track in ``genfit_tracks`` is tagged
    via ``setMcTrackId`` with the result of ``find_MC_track`` (-1 when
    unmatched).

    NOTE(review): vertices are not truth matched here; ``match_vertex`` is
    not called by this entry point.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "-f",
        "--inputfile",
        help="Simulation results to use as input. "
        "Supports retrieving file from EOS via the XRootD protocol.",
        required=True,
    )
    parser.add_argument(
        "-o",
        "--outputfile",
        help="File to write the truth-matched tree to. "
        "Will be recreated if it already exists. "
        "Defaults to the input name with '.root' replaced by '_MCTruth.root'.",
    )
    args = parser.parse_args()
    if not args.outputfile:
        args.outputfile = args.inputfile.removesuffix(".root") + "_MCTruth.root"

    inputfile = ROOT.TFile.Open(args.inputfile, "read")
    # TFile.Open gives a null (falsy) object on failure; fail with a clear message.
    if not inputfile or inputfile.IsZombie():
        parser.error(f"could not open input file {args.inputfile}")
    try:
        tree = inputfile.cbmsim
        outputfile = ROOT.TFile.Open(args.outputfile, "recreate")
        if not outputfile or outputfile.IsZombie():
            parser.error(f"could not open output file {args.outputfile}")
        try:
            # Opening outputfile made it the current directory, so the empty
            # clone is created there. It shares branch buffers with `tree`,
            # so the in-place setMcTrackId edits are what gets filled.
            out_tree = tree.CloneTree(0)
            for event in tqdm(tree, desc="Event loop: ", total=tree.GetEntries()):
                for track in event.genfit_tracks:
                    track_id = find_MC_track(track, event)
                    track.setMcTrackId(track_id)
                out_tree.Fill()
            # Write the tree once. A further outputfile.Write() would store
            # it a second time as an extra key cycle.
            out_tree.Write()
        finally:
            outputfile.Close()
    finally:
        inputfile.Close()
    logging.info("Wrote truth-matched tree to %s", args.outputfile)
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging before running the CLI,
    # so log records emitted during main() are shown.
    logging.basicConfig(level=logging.INFO)
    main()