Skip to content

Commit b4f7cbc

Browse files
authored
Merge pull request graphnet-team#775 from Aske-Rosted/file_list_ensemble
Ensemble from list of paths
2 parents 92b150a + 79b0479 commit b4f7cbc

2 files changed

Lines changed: 29 additions & 1 deletion

File tree

src/graphnet/data/dataset/dataset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,16 @@ def from_config( # type: ignore[override]
141141
cfg["graph_definition"] = parse_graph_definition(cfg)
142142
if cfg["labels"] is not None:
143143
cfg["labels"] = parse_labels(cfg)
144-
return source._dataset_class(**cfg)
144+
145+
if isinstance(cfg["path"], list):
146+
sources = []
147+
for path in cfg["path"]:
148+
cfg["path"] = path
149+
sources.append(source._dataset_class(**cfg))
150+
source = EnsembleDataset(sources)
151+
return source
152+
else:
153+
return source._dataset_class(**cfg)
145154

146155
@classmethod
147156
def concatenate(

tests/utilities/test_dataset_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,22 @@ def test_dataset_config_files(backend: str) -> None:
269269
)
270270
== 0
271271
)
272+
273+
274+
@pytest.mark.order(6)
275+
@pytest.mark.parametrize("backend", ["sqlite"])
276+
def test_multiple_dataset_config_dict_selection(backend: str) -> None:
277+
"""Test constructing Dataset with multiple data paths."""
278+
# Arrange
279+
config_path = CONFIG_PATHS[backend]
280+
281+
# Single dataset
282+
config = DatasetConfig.load(config_path)
283+
dataset = Dataset.from_config(config)
284+
# Construct multiple datasets
285+
config_ensemble = DatasetConfig.load(config_path)
286+
config_ensemble.path = [config_ensemble.path, config_ensemble.path]
287+
288+
ensemble_dataset = Dataset.from_config(config_ensemble)
289+
290+
assert len(dataset) * 2 == len(ensemble_dataset)

0 commit comments

Comments
 (0)