Skip to content

Commit 4bcc607

Browse files
authored
Merge pull request graphnet-team#876 from Aske-Rosted/lmdb_test_fix
Lmdb test fix
2 parents 8477757 + 0dead35 commit 4bcc607

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

src/graphnet/data/dataset/lmdb/lmdb_dataset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def _post_init(self) -> None:
169169
if self._pre_computed_representation is None:
170170
# Only check for missing columns if using raw tables
171171
self._remove_missing_columns()
172-
self._close_connection()
173172
if self._pre_computed_representation is not None:
174173
self._identify_missing_truth_labels()
174+
self._close_connection()
175175

176176
def _identify_missing_truth_labels(self) -> None:
177177
"""Identify missing truth labels in the pre-computed representation."""
@@ -430,3 +430,7 @@ def __getitem__(self, sequential_index: int) -> Any:
430430
data = super().__getitem__(sequential_index)
431431

432432
return data
433+
434+
def close(self) -> None:
435+
"""Close any open LMDB connections."""
436+
self._close_connection()

tests/data/test_dataconverters_and_datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,6 @@ def test_sqlite_to_lmdb_converter() -> None:
296296
truth=TRUTH.DEEPCORE,
297297
graph_definition=graph_definition,
298298
)
299-
300299
dataset_from_lmdb_raw = LMDBDataset(path, **opt_raw) # type: ignore
301300
dataset_sqlite = SQLiteDataset(
302301
get_file_path("sqlite"), **opt_raw # type: ignore[arg-type]
@@ -310,6 +309,7 @@ def test_sqlite_to_lmdb_converter() -> None:
310309
dataset_from_lmdb_raw[ix].x, dataset_sqlite[ix].x
311310
)
312311

312+
dataset_from_lmdb_raw.close() # Close connection
313313
# Test 2: Check that pre-computed representation matches real-time computed
314314
# The pre-computed representation field name is the class name
315315
pre_computed_field_name = graph_definition.__class__.__name__
@@ -361,3 +361,4 @@ def test_sqlite_to_lmdb_converter() -> None:
361361
)
362362
else:
363363
assert precomputed_truth == realtime_truth
364+
dataset_from_lmdb_precomputed.close() # Close connection

0 commit comments

Comments
 (0)