Skip to content

Commit 78843af

Browse files
committed
runID assertion and PR suggestions
1 parent c31582f commit 78843af

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

src/graphnet/datasets/snowstorm_dataset.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from graphnet.data.utilities import query_database
1313
from graphnet.models.graphs import GraphDefinition
1414

15+
AVAILABLE_RUN_IDS = [
16+
*list(range(22010, 22019)),
17+
*list(range(22042, 22051)),
18+
*list(range(22078, 22087)),
19+
]
20+
1521

1622
class SnowStormDataset(IceCubeHostedDataset):
1723
"""IceCube SnowStorm simulation dataset.
@@ -23,7 +29,7 @@ class SnowStormDataset(IceCubeHostedDataset):
2329
"""
2430

2531
_experiment = "IceCube SnowStorm dataset"
26-
_creator = "Severin Magel"
32+
_creator = "Aske Rosted"
2733
_citation = "arXiv:1909.01530"
2834
_available_backends = ["sqlite"]
2935

@@ -45,7 +51,27 @@ def __init__(
4551
validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
4652
test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
4753
):
48-
"""Initialize SnowStorm dataset."""
54+
"""Construct SnowStormDataset.
55+
56+
Args:
57+
run_ids: List of RunIDs to include.
58+
graph_definition: Method that defines the data representation.
59+
download_dir: Directory to download dataset to.
60+
truth (Optional): List of event-level truth to include. Will
61+
include all available information if not given.
62+
features (Optional): List of input features from pulsemap to use.
63+
If not given, all available features will be
64+
used.
65+
train_dataloader_kwargs (Optional): Arguments for the training
66+
DataLoader. Default None.
67+
validation_dataloader_kwargs (Optional): Arguments for the
68+
validation DataLoader, Default None.
69+
test_dataloader_kwargs (Optional): Arguments for the test
70+
DataLoader. Default None.
71+
"""
72+
assert all(
73+
[i in AVAILABLE_RUN_IDS for i in run_ids]
74+
), f"RunIDs must be in {AVAILABLE_RUN_IDS}. You provided {run_ids}"
4975
self._run_ids = run_ids
5076
self._zipped_files = [
5177
os.path.join(self._data_root_dir, f"{s}.tar.gz") for s in run_ids

0 commit comments

Comments
 (0)