Skip to content

Commit 384644b

Browse files
authored
Merge pull request graphnet-team#794 from Aske-Rosted/writer_changes
Writer changes
2 parents 4bee06e + 50415bd commit 384644b

3 files changed

Lines changed: 135 additions & 46 deletions

File tree

src/graphnet/data/pre_configured/dataconverters.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Pre-configured combinations of writers and readers."""
22

3-
from typing import List, Union
3+
from typing import List, Union, Optional
44

55
from graphnet.data import DataConverter
66
from graphnet.data.readers import I3Reader, ParquetReader
@@ -68,6 +68,7 @@ def __init__(
6868
index_column: str = "event_no",
6969
num_workers: int = 1,
7070
i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore
71+
max_table_size: Optional[int] = None,
7172
):
7273
"""Convert I3 files to SQLite.
7374
@@ -92,10 +93,11 @@ def __init__(
9293
Defaults to 1 (no multiprocessing).
9394
i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to
9495
`NullSplitI3Filter`.
96+
max_table_size: Maximum size of the SQLite tables. Default None.
9597
"""
9698
super().__init__(
9799
file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters),
98-
save_method=SQLiteWriter(),
100+
save_method=SQLiteWriter(max_table_size=max_table_size),
99101
extractors=extractors,
100102
num_workers=num_workers,
101103
index_column=index_column,

src/graphnet/data/writers/sqlite_writer.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import os
88
from tqdm import tqdm
9-
from typing import List, Dict, Optional
9+
from typing import List, Dict, Optional, Union, Tuple
1010

1111
from graphnet.data.utilities import (
1212
create_table_and_save_to_sql,
@@ -100,6 +100,8 @@ def merge_files(
100100
files: List[str],
101101
output_dir: str,
102102
primary_key_rescue: str = "event_no",
103+
remove_originals: bool = False,
104+
reset_integer_primary_key: bool = False,
103105
) -> None:
104106
"""SQLite-specific method for merging output files/databases.
105107
@@ -117,6 +119,12 @@ def merge_files(
117119
primary_key_rescue: The name of the columns on which the primary
118120
key is constructed. This will only be used if it is not
119121
possible to infer the primary key name.
122+
remove_originals: If True, the original files will be removed
123+
after merging.
124+
reset_integer_primary_key: If True, the primary key will be reset
125+
when merging the databases. This is useful if the files
126+
were not generated by the same process and therefore
127+
have overlapping index values.
120128
"""
121129
# Warnings
122130
if self._max_table_size:
@@ -136,14 +144,21 @@ def merge_files(
136144
if len(files) > 0:
137145
os.makedirs(output_dir, exist_ok=True)
138146
self.info(f"Merging {len(files)} database files")
139-
self._merge_databases(files=files, database_path=database_path)
147+
self._merge_databases(
148+
files=files,
149+
database_path=database_path,
150+
remove_originals=remove_originals,
151+
reset_integer_primary_key=reset_integer_primary_key,
152+
)
140153
else:
141154
self.warning("No database files given! Exiting.")
142155

143156
def _merge_databases(
144157
self,
145158
files: List[str],
146159
database_path: str,
160+
remove_originals: bool = False,
161+
reset_integer_primary_key: bool = False,
147162
) -> None:
148163
"""Merge the temporary databases.
149164
@@ -152,6 +167,10 @@ def _merge_databases(
152167
database_path: Path to a database, can be an empty path, where the
153168
databases listed in `files` will be merged into. If no database
154169
exists at the given path, one will be created.
170+
remove_originals: If True, the original files will be removed
171+
after merging.
172+
reset_integer_primary_key: If True, the primary key will be reset
173+
when merging the databases.
155174
"""
156175
if os.path.exists(database_path):
157176
self.warning(
@@ -165,6 +184,7 @@ def _merge_databases(
165184
self._largest_table = 0
166185

167186
# Merge temporary databases into newly created one
187+
start_count = 0
168188
for file_count, input_file in tqdm(enumerate(files), colour="green"):
169189
# Extract table names and index column name in database
170190
try:
@@ -177,7 +197,18 @@ def _merge_databases(
177197
continue
178198
else:
179199
raise e
180-
200+
integer_primary_dict = None
201+
if reset_integer_primary_key:
202+
if list(tables.values())[0] is None:
203+
# Ensure that the first table is indexed
204+
temp_tables: Dict[str, Union[str, None]] = {
205+
k: v for k, v in tables.items() if v is not None
206+
}
207+
# re-add the non-indexed tables to the dictionary
208+
temp_tables.update(
209+
{k: v for k, v in tables.items() if v is None}
210+
)
211+
tables = temp_tables
181212
for table_name in tables.keys():
182213
# Extract all data in the table from the given database
183214
df = query_database(
@@ -190,6 +221,22 @@ def _merge_databases(
190221
True if tables[table_name] is not None else False
191222
)
192223

224+
if reset_integer_primary_key & (integer_primary_dict is None):
225+
integer_primary_dict, start_count = (
226+
self._generate_integer_primary_dict(
227+
start_count, df[primary_key].values
228+
)
229+
)
230+
df = self._map_integer_primary_keys(
231+
df, primary_key, integer_primary_dict
232+
)
233+
elif reset_integer_primary_key & (
234+
integer_primary_dict is not None
235+
):
236+
df = self._map_integer_primary_keys(
237+
df, primary_key, integer_primary_dict
238+
)
239+
193240
# Submit to new database
194241
create_table_and_save_to_sql(
195242
df=df,
@@ -216,8 +263,31 @@ def _merge_databases(
216263
"Maximum row count reached."
217264
f" Creating new partition at {database_path}"
218265
)
266+
if remove_originals:
267+
# Remove original files after all files have been merged
268+
for input_file in files:
269+
os.remove(input_file)
219270

220271
# Internal methods
272+
def _generate_integer_primary_dict(
273+
self, start_count: int, integer_primary_keys_array: List[int]
274+
) -> Tuple[Dict[int, int], int]:
275+
integer_primary_dict = {}
276+
for i, key in enumerate(integer_primary_keys_array):
277+
integer_primary_dict[key] = start_count + i
278+
return integer_primary_dict, start_count + len(
279+
integer_primary_keys_array
280+
)
281+
282+
def _map_integer_primary_keys(
283+
self,
284+
df: pd.DataFrame,
285+
key: str,
286+
integer_primary_dict: Union[Dict[int, int], None],
287+
) -> pd.DataFrame:
288+
assert integer_primary_dict is not None # mypy...
289+
df[key] = df[key].map(integer_primary_dict)
290+
return df
221291

222292
def _adjust_output_path(self, output_file: str) -> str:
223293
"""Adjust the file path to reflect that it is a partition."""

src/graphnet/utilities/filesys.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pathlib import Path
44
import re
5+
import os
56
from typing import List, Optional, Tuple, Union
67

78

@@ -31,7 +32,7 @@ def has_extension(filename: str, extensions: List[str]) -> bool:
3132

3233

3334
def find_i3_files(
34-
directories: Union[str, List[str]],
35+
inputs: Union[str, List[str]],
3536
gcd_rescue: Optional[str] = None,
3637
recursive: Optional[bool] = True,
3738
) -> Tuple[List[str], List[str]]:
@@ -42,7 +43,8 @@ def find_i3_files(
4243
in the directory.
4344
4445
Args:
45-
directories: Directories to search recursively for I3 files.
46+
inputs: Directories to search recursively for I3 files.
47+
Or list of I3 files.
4648
gcd_rescue: Path to the GCD that will be default if no GCD is present
4749
in the directory.
4850
recursive: Whether or not to search the directories recursively.
@@ -51,46 +53,61 @@ def find_i3_files(
5153
i3_list: Paths to I3 files in `directories`
5254
gcd_list: Paths to GCD files for each I3 file.
5355
"""
54-
if isinstance(directories, str):
55-
directories = [directories]
56+
if isinstance(inputs, str):
57+
inputs = [inputs]
5658

5759
# Output containers
5860
i3_files = []
5961
gcd_files = []
60-
61-
for directory in directories:
62-
63-
# Find all I3-like files in `directory`, may or may not be recursively.
64-
paths = []
65-
i3_patterns = ["*.bz2", "*.zst", "*.gz"]
66-
for i3_pattern in i3_patterns:
67-
if recursive:
68-
paths.extend(list(Path(directory).rglob(i3_pattern)))
69-
else:
70-
paths.extend(list(Path(directory).glob(i3_pattern)))
71-
72-
# Loop over all folders containing such I3-like files.
73-
folders = sorted(set([path.parent for path in paths]))
74-
for folder in folders:
75-
76-
# List all I3 and GCD files, respectively, in the current folder.
77-
folder_files = [
78-
str(path) for path in paths if path.parent == folder
79-
]
80-
folder_i3_files = list(filter(is_i3_file, folder_files))
81-
folder_gcd_files = list(filter(is_gcd_file, folder_files))
82-
83-
# Make sure that no more than one GCD file is found;
84-
# and use rescue file if none is found.
85-
assert len(folder_gcd_files) <= 1
86-
if len(folder_gcd_files) == 0:
87-
assert gcd_rescue is not None
88-
folder_gcd_files = [gcd_rescue]
89-
90-
# Store list of I3 files and corresponding GCD files.
91-
folder_gcd_files = folder_gcd_files * len(folder_i3_files)
92-
93-
gcd_files.extend(folder_gcd_files)
94-
i3_files.extend(folder_i3_files)
95-
96-
return i3_files, gcd_files
62+
if all([is_i3_file(input) for input in inputs]):
63+
print("Assuming list of files.")
64+
assert gcd_rescue is not None
65+
gcd_files = [gcd_rescue] * len(inputs)
66+
return inputs, gcd_files
67+
68+
elif all(os.path.isdir(input) for input in inputs):
69+
print("Assuming list of directories.")
70+
71+
for directory in inputs:
72+
73+
# Find all I3-like files in `directory`.
74+
paths = []
75+
i3_patterns = ["*.bz2", "*.zst", "*.gz"]
76+
for i3_pattern in i3_patterns:
77+
if recursive:
78+
paths.extend(list(Path(directory).rglob(i3_pattern)))
79+
else:
80+
paths.extend(list(Path(directory).glob(i3_pattern)))
81+
82+
# Loop over all folders containing such I3-like files.
83+
folders = sorted(set([path.parent for path in paths]))
84+
for folder in folders:
85+
86+
# List all I3 and GCD files, in the current folder.
87+
folder_files = [
88+
str(path) for path in paths if path.parent == folder
89+
]
90+
folder_i3_files = list(filter(is_i3_file, folder_files))
91+
folder_gcd_files = list(filter(is_gcd_file, folder_files))
92+
93+
# Make sure that no more than one GCD file is found;
94+
# and use rescue file if none is found.
95+
assert len(folder_gcd_files) <= 1
96+
if len(folder_gcd_files) == 0:
97+
assert gcd_rescue is not None
98+
folder_gcd_files = [gcd_rescue]
99+
100+
# Store list of I3 files and corresponding GCD files.
101+
folder_gcd_files = folder_gcd_files * len(folder_i3_files)
102+
103+
gcd_files.extend(folder_gcd_files)
104+
i3_files.extend(folder_i3_files)
105+
return i3_files, gcd_files
106+
else:
107+
if any([os.path.isdir(input) for input in inputs]):
108+
raise ValueError(
109+
"Inputs contains a mix of files and directories \
110+
which is not supported."
111+
)
112+
else:
113+
raise ValueError("Some inputs are not valid directories or files.")

0 commit comments

Comments
 (0)