Skip to content

Commit 62f2a01

Browse files
committed
Big update
1 parent 9b1d49c commit 62f2a01

28 files changed

Lines changed: 3089 additions & 156 deletions

src/graphnet/data/dataloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.utils.data
66
from torch_geometric.data import Batch, Data
77

8-
from graphnet.data.dataset import Dataset
8+
from graphnet.data.dataset import Dataset, EnsembleDataset
99
from graphnet.utilities.config import DatasetConfig
1010

1111

@@ -81,5 +81,5 @@ def from_dataset_config(
8181
"need to specify `shuffle` as an argument."
8282
)
8383
dataset = Dataset.from_config(config)
84-
assert isinstance(dataset, Dataset)
84+
assert isinstance(dataset, Union[Dataset, EnsembleDataset])
8585
return cls(dataset, **kwargs)

src/graphnet/data/dataset/dataset.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,15 @@ def from_config( # type: ignore[override]
169169

170170
if isinstance(cfg["path"], list):
171171
sources = []
172+
msg = f"Constructing {len(cfg['path'])} datasets, with selection: {source.selection}"
173+
msg_bool = True
172174
for path in cfg["path"]:
173175
cfg["path"] = path
174-
sources.append(source._dataset_class(**cfg))
176+
sources.append(source._dataset_class(**cfg, verbose=False))
177+
if msg_bool:
178+
sources[-1].info(msg)
179+
msg_bool = False
180+
175181
source = EnsembleDataset(sources)
176182
return source
177183
else:
@@ -259,6 +265,8 @@ def __init__(
259265
loss_weight_default_value: Optional[float] = None,
260266
seed: Optional[int] = None,
261267
labels: Optional[Dict[str, Any]] = None,
268+
use_super_selection: bool = False,
269+
verbose: bool = True,
262270
):
263271
"""Construct Dataset.
264272
@@ -311,6 +319,11 @@ def __init__(
311319
NOTE: DEPRECATED Use `data_representation` instead.
312320
# DEPRECATION: REMOVE AT 2.0 LAUNCH
313321
# See https://github.com/graphnet-team/graphnet/issues/647
322+
use_super_selection: If True, the string selection is handled by
323+
the query function of the dataset class, rather than
324+
pd.DataFrame.query. Defaults to False and should
325+
only be used with sqlite.
326+
verbose: Whether to print the selection info
314327
"""
315328
# Base class constructor
316329
super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -354,6 +367,7 @@ def __init__(
354367
self._data_representation = deepcopy(data_representation)
355368
self._labels = labels
356369
self._string_column = data_representation._detector.string_index_name
370+
self._use_super_selection = use_super_selection
357371

358372
if node_truth is not None:
359373
assert isinstance(node_truth_table, str)
@@ -404,6 +418,7 @@ def __init__(
404418
self,
405419
index_column=index_column,
406420
seed=seed,
421+
use_super_selection=self._use_super_selection,
407422
)
408423

409424
if self._labels is not None:
@@ -419,7 +434,8 @@ def __init__(
419434
self._indices = self._get_all_indices()
420435
elif isinstance(selection, str):
421436
self._indices = self._resolve_string_selection_to_indices(
422-
selection
437+
selection,
438+
verbose=verbose,
423439
)
424440
else:
425441
self._indices = selection
@@ -528,7 +544,7 @@ def __getitem__(self, sequential_index: int) -> Data:
528544

529545
# Internal method(s)
530546
def _resolve_string_selection_to_indices(
531-
self, selection: str
547+
self, selection: str, verbose: bool = True
532548
) -> List[int]:
533549
"""Resolve selection as string to list of indices.
534550
@@ -537,7 +553,9 @@ def _resolve_string_selection_to_indices(
537553
fixed number of events to randomly sample, e.g., ``` "10000 random
538554
events ~ event_no % 5 > 0" "20% random events ~ event_no % 5 > 0" ```
539555
"""
540-
return self._string_selection_resolver.resolve(selection)
556+
return self._string_selection_resolver.resolve(
557+
selection, verbose=verbose
558+
)
541559

542560
def _remove_missing_columns(self) -> None:
543561
"""Remove columns that are not present in the input file.
@@ -585,7 +603,7 @@ def _remove_missing_columns(self) -> None:
585603
def _check_missing_columns(
586604
self,
587605
columns: List[str],
588-
table: str,
606+
table: Union[str, List[str]],
589607
) -> List[str]:
590608
"""Return a list missing columns in `table`."""
591609
for column in columns:
@@ -594,13 +612,26 @@ def _check_missing_columns(
594612
table=table, columns=[column], sequential_index=0
595613
)
596614
except ColumnMissingException:
597-
if table not in self._missing_variables:
598-
self._missing_variables[table] = []
599-
self._missing_variables[table].append(column)
615+
if isinstance(table, str):
616+
if table not in self._missing_variables:
617+
self._missing_variables[table] = []
618+
self._missing_variables[table].append(column)
619+
elif isinstance(table, list):
620+
for t in table:
621+
if t not in self._missing_variables:
622+
self._missing_variables[t] = []
623+
self._missing_variables[t].append(column)
600624
except IndexError:
601625
self.warning(f"Dataset contains no entries for {column}")
602-
603-
return self._missing_variables.get(table, [])
626+
if isinstance(table, str):
627+
missing_variables = self._missing_variables.get(table, [])
628+
elif isinstance(table, list):
629+
missing_variables = [
630+
value
631+
for key, value in self._missing_variables.items()
632+
if key in table
633+
]
634+
return missing_variables
604635

605636
def _query(
606637
self, sequential_index: int
@@ -677,10 +708,13 @@ def _create_graph(
677708
"""
678709
# Convert truth to dict
679710
if len(truth.shape) == 1:
680-
truth = truth.reshape(1, -1)
681-
truth_dict = {
682-
key: truth[:, index] for index, key in enumerate(self._truth)
683-
}
711+
truth_dict = {
712+
key: truth[0][index] for index, key in enumerate(self._truth)
713+
}
714+
else:
715+
truth_dict = {
716+
key: truth[:, index] for index, key in enumerate(self._truth)
717+
}
684718

685719
# Define custom labels
686720
labels_dict = self._get_labels(truth_dict)

src/graphnet/data/dataset/sqlite/sqlite_dataset.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,19 @@ def query_table(
7272
f"{self._index_column} = {index} and {selection}"
7373
)
7474

75-
result = self._conn.execute(
76-
f"SELECT {columns} FROM {table} WHERE "
77-
f"{combined_selections}"
78-
).fetchall()
75+
if isinstance(table, list):
76+
SELECT_QUERY = f"SELECT {columns} FROM"
77+
JOIN_TABLES = " JOIN ".join(table)
78+
USING_CLAUSE = f"USING({self._index_column})"
79+
WHERE_CLAUSE = f"WHERE {combined_selections}"
80+
FULL_QUERY = f"{SELECT_QUERY} {JOIN_TABLES} {USING_CLAUSE} {WHERE_CLAUSE}"
81+
result = self._conn.execute(FULL_QUERY).fetchall()
82+
else:
83+
result = self._conn.execute(
84+
f"SELECT {columns} FROM {table} WHERE "
85+
f"{combined_selections}"
86+
).fetchall()
87+
7988
except sqlite3.OperationalError as e:
8089
if "no such column" in str(e):
8190
raise ColumnMissingException(str(e))
@@ -151,3 +160,18 @@ def _close_connection(self) -> "SQLiteDataset":
151160
self._all_connections_established = False
152161
self._conn = None
153162
return self
163+
164+
def _join_tables(self, tables, columns):
165+
"""Join tables in the SQLite database."""
166+
# Check(s)
167+
if not isinstance(tables, list):
168+
raise TypeError("Input must be a list of table names.")
169+
if len(tables) == 0:
170+
raise ValueError(
171+
"Input list must contain at least one table name."
172+
)
173+
tables = ", ".join(tables)
174+
self._conn.execute("DROP VIEW IF EXISTS combined_table")
175+
self._conn.execute(
176+
f"CREATE VIEW combined_table AS SELECT {columns} FROM {tables} OUTER JOIN {tables} ON({self._index_column})"
177+
)

src/graphnet/data/utilities/sqlite_utilities.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,179 @@ def create_table_and_save_to_sql(
206206
integer_primary_key=integer_primary_key,
207207
)
208208
save_to_sql(df, table_name=table_name, database_path=database_path)
209+
210+
211+
def get_first_pulse_times(
    database_path: str,
    pulses_table_name: str = "SRTInIcePulses",
    time_column: str = "dom_time",
    index_column: str = "event_no",
) -> pd.DataFrame:
    """Return the earliest pulse time for every event.

    Args:
        database_path: Path to the database.
        pulses_table_name: Name of the pulses table.
        time_column: Name of the time column in the pulses table.
        index_column: Name of the index column in the pulses table.

    Returns:
        DataFrame with two columns: `event_no` and `first_pulse_time`.
    """
    # Aggregate with MIN() grouped on the index column: one row per event.
    sql = (
        f"SELECT {index_column}, MIN({time_column}) AS first_pulse_time"
        f" FROM {pulses_table_name}"
        f" GROUP BY {index_column};"
    )
    return query_database(database_path, sql)
234+
235+
236+
def add_first_pulse_time_to_truth(
    database_path: str,
    truth_table_name: str = "truth",
    pulses_table_name: str = "SRTInIcePulses",
    time_column: str = "dom_time",
    index_column: str = "event_no",
) -> None:
    """Add a `first_pulse_time` column to the truth table.

    The earliest pulse time per event is computed from the pulses table,
    staged in a temporary table, copied into the truth table, and the
    temporary table is dropped again. Safe to re-run: an existing
    `first_pulse_time` column is tolerated and overwritten.

    Args:
        database_path: Path to the database.
        truth_table_name: Name of the truth table.
        pulses_table_name: Name of the pulses table.
        time_column: Name of the time column in the pulses table.
        index_column: Name of the index column in both tables.
    """
    import sqlite3  # local import: only used to classify ALTER failures

    # Get first pulse times
    df = get_first_pulse_times(
        database_path=database_path,
        pulses_table_name=pulses_table_name,
        time_column=time_column,
        index_column=index_column,
    )
    print(f"Finished getting first pulse times for {len(df)} events.")
    # Create temporary table for first pulse times
    temp_table_name = "temp_first_pulse_times"

    query = f"DROP TABLE IF EXISTS {temp_table_name};"
    run_sql_code(database_path, query)

    create_table(
        columns=["event_no", "first_pulse_time"],
        table_name=temp_table_name,
        database_path=database_path,
        index_column=index_column,
        default_type="FLOAT",
        integer_primary_key=True,
    )
    print(f"Created temporary table {temp_table_name} for first pulse times.")
    # Save first pulse times to temporary table
    save_to_sql(
        df=df,
        table_name=temp_table_name,
        database_path=database_path,
    )

    # Create the column in the truth table; tolerate it already existing
    # from a previous run (the original unconditional ALTER crashed on
    # re-runs despite the comment claiming it handled that case).
    query = (
        f"ALTER TABLE {truth_table_name} "
        f"ADD COLUMN first_pulse_time FLOAT;"
    )
    print(f"Adding column 'first_pulse_time' to {truth_table_name}.")
    try:
        run_sql_code(database_path, query)
    except sqlite3.OperationalError as e:
        if "duplicate column" not in str(e).lower():
            raise

    query = (
        f"UPDATE {truth_table_name} "
        f"SET first_pulse_time = (SELECT first_pulse_time "
        f"FROM {temp_table_name} "
        f"WHERE {temp_table_name}.{index_column} = {truth_table_name}.{index_column});"
    )

    run_sql_code(database_path, query)
    print(
        f"Updated {truth_table_name} with first pulse times from {temp_table_name}."
    )
    # Drop the temporary table
    query = f"DROP TABLE IF EXISTS {temp_table_name};"
    print(f"Dropping temporary table {temp_table_name}.")
    run_sql_code(database_path, query)
306+
307+
308+
def add_starting(
    database_path: str,
    truth_table_name: str = "truth",
    containment_column: str = "containment_type",
    index_column: str = "event_no",
) -> None:
    """Add a binary `starting` column to the truth table.

    Maps the containment-type enum to 1 (starting event) or 0 (not
    starting), stages the result in a temporary table, copies it into the
    truth table, and drops the temporary table again. Safe to re-run: an
    existing `starting` column is tolerated and overwritten.

    Args:
        database_path: Path to the database.
        truth_table_name: Name of the truth table.
        containment_column: Name of the containment-type column in the
            truth table.
        index_column: Name of the index column in both tables.
    """
    import sqlite3  # local import: only used to classify ALTER failures

    # mapping from containment enum to starting
    map_dict = {
        1: 0,  # no intersect: not starting
        2: 0,  # through-going: not starting
        3: 1,  # contained: starting
        4: 1,  # tau-to-mu: starting
        5: 1,  # uncontained-starting: starting
        6: 0,  # stopping: not starting
        7: 0,  # decayed: not starting
        8: 0,  # through-going bundle: not starting
        9: 0,  # stopping bundle: not starting
        10: 1,  # partial-contained: starting
    }

    containment_type_query = (
        f"SELECT {index_column}, {containment_column} "
        f"FROM {truth_table_name};"
    )

    containment_df = query_database(database_path, containment_type_query)

    # convert containment type to starting using map_dict
    containment_df["starting"] = (
        containment_df[containment_column].astype(int).map(map_dict)
    )

    temp_table_name = "temp_starting"
    query = f"DROP TABLE IF EXISTS {temp_table_name};"
    run_sql_code(database_path, query)

    create_table(
        columns=[index_column, "starting"],
        table_name=temp_table_name,
        database_path=database_path,
        index_column=index_column,
        default_type="INTEGER",
        integer_primary_key=True,
    )

    print(f"Created temporary table {temp_table_name} for starting.")
    # Save starting to temporary table
    save_to_sql(
        df=containment_df[[index_column, "starting"]],
        table_name=temp_table_name,
        database_path=database_path,
    )
    # Create the column in the truth table; tolerate it already existing
    # from a previous run (the original unconditional ALTER crashed on
    # re-runs despite the comment claiming it handled that case).
    query = f"ALTER TABLE {truth_table_name} " f"ADD COLUMN starting INTEGER;"
    print(f"Adding column 'starting' to {truth_table_name}.")
    try:
        run_sql_code(database_path, query)
    except sqlite3.OperationalError as e:
        if "duplicate column" not in str(e).lower():
            raise
    query = (
        f"UPDATE {truth_table_name} "
        f"SET starting = (SELECT starting "
        f"FROM {temp_table_name} "
        f"WHERE {temp_table_name}.{index_column} = {truth_table_name}.{index_column});"
    )

    run_sql_code(database_path, query)
    print(f"Updated {truth_table_name} with starting from {temp_table_name}.")
    # Drop the temporary table
    query = f"DROP TABLE IF EXISTS {temp_table_name};"
    print(f"Dropping temporary table {temp_table_name}.")
    run_sql_code(database_path, query)

0 commit comments

Comments
 (0)