66
77import os
88from tqdm import tqdm
9- from typing import List , Dict , Optional
9+ from typing import List , Dict , Optional , Union , Tuple
1010
1111from graphnet .data .utilities import (
1212 create_table_and_save_to_sql ,
@@ -100,6 +100,8 @@ def merge_files(
100100 files : List [str ],
101101 output_dir : str ,
102102 primary_key_rescue : str = "event_no" ,
103+ remove_originals : bool = False ,
104+ reset_integer_primary_key : bool = False ,
103105 ) -> None :
104106 """SQLite-specific method for merging output files/databases.
105107
@@ -117,6 +119,12 @@ def merge_files(
117119 primary_key_rescue: The name of the columns on which the primary
118120 key is constructed. This will only be used if it is not
119121 possible to infer the primary key name.
122+ remove_originals: If True, the original files will be removed
123+ after merging.
124+ reset_integer_primary_key: If True, the primary key will reset
125+ when merging the databases. This is useful if the files
126+ were not generated by the same process and therefore
127+ have overlapping index values.
120128 """
121129 # Warnings
122130 if self ._max_table_size :
@@ -136,14 +144,21 @@ def merge_files(
136144 if len (files ) > 0 :
137145 os .makedirs (output_dir , exist_ok = True )
138146 self .info (f"Merging { len (files )} database files" )
139- self ._merge_databases (files = files , database_path = database_path )
147+ self ._merge_databases (
148+ files = files ,
149+ database_path = database_path ,
150+ remove_originals = remove_originals ,
151+ reset_integer_primary_key = reset_integer_primary_key ,
152+ )
140153 else :
141154 self .warning ("No database files given! Exiting." )
142155
143156 def _merge_databases (
144157 self ,
145158 files : List [str ],
146159 database_path : str ,
160+ remove_originals : bool = False ,
161+ reset_integer_primary_key : bool = False ,
147162 ) -> None :
148163 """Merge the temporary databases.
149164
@@ -152,6 +167,10 @@ def _merge_databases(
152167 database_path: Path to a database, can be an empty path, where the
153168 databases listed in `files` will be merged into. If no database
154169 exists at the given path, one will be created.
170+ remove_originals: If True, the original files will be removed
171+ after merging.
172+ reset_integer_primary_key: If True, the primary key will reset
173+ when merging the databases.
155174 """
156175 if os .path .exists (database_path ):
157176 self .warning (
@@ -165,6 +184,7 @@ def _merge_databases(
165184 self ._largest_table = 0
166185
167186 # Merge temporary databases into newly created one
187+ start_count = 0
168188 for file_count , input_file in tqdm (enumerate (files ), colour = "green" ):
169189 # Extract table names and index column name in database
170190 try :
@@ -177,7 +197,18 @@ def _merge_databases(
177197 continue
178198 else :
179199 raise e
180-
200+ integer_primary_dict = None
201+ if reset_integer_primary_key :
202+ if list (tables .values ())[0 ] is None :
203+ # Ensure that the first table is indexed
204+ temp_tables : Dict [str , Union [str , None ]] = {
205+ k : v for k , v in tables .items () if v is not None
206+ }
207+ # re-add the non-indexed tables to the dictionary
208+ temp_tables .update (
209+ {k : v for k , v in tables .items () if v is None }
210+ )
211+ tables = temp_tables
181212 for table_name in tables .keys ():
182213 # Extract all data in the table from the given database
183214 df = query_database (
@@ -190,6 +221,22 @@ def _merge_databases(
190221 True if tables [table_name ] is not None else False
191222 )
192223
224+ if reset_integer_primary_key & (integer_primary_dict is None ):
225+ integer_primary_dict , start_count = (
226+ self ._generate_integer_primary_dict (
227+ start_count , df [primary_key ].values
228+ )
229+ )
230+ df = self ._map_integer_primary_keys (
231+ df , primary_key , integer_primary_dict
232+ )
233+ elif reset_integer_primary_key & (
234+ integer_primary_dict is not None
235+ ):
236+ df = self ._map_integer_primary_keys (
237+ df , primary_key , integer_primary_dict
238+ )
239+
193240 # Submit to new database
194241 create_table_and_save_to_sql (
195242 df = df ,
@@ -216,8 +263,31 @@ def _merge_databases(
216263 "Maximum row count reached."
217264 f" Creating new partition at { database_path } "
218265 )
266+ if remove_originals :
267+ # Remove original files after all files have been merged
268+ for input_file in files :
269+ os .remove (input_file )
219270
220271 # Internal methods
272+ def _generate_integer_primary_dict (
273+ self , start_count : int , integer_primary_keys_array : List [int ]
274+ ) -> Tuple [Dict [int , int ], int ]:
275+ integer_primary_dict = {}
276+ for i , key in enumerate (integer_primary_keys_array ):
277+ integer_primary_dict [key ] = start_count + i
278+ return integer_primary_dict , start_count + len (
279+ integer_primary_keys_array
280+ )
281+
282+ def _map_integer_primary_keys (
283+ self ,
284+ df : pd .DataFrame ,
285+ key : str ,
286+ integer_primary_dict : Union [Dict [int , int ], None ],
287+ ) -> pd .DataFrame :
288+ assert integer_primary_dict is not None # mypy...
289+ df [key ] = df [key ].map (integer_primary_dict )
290+ return df
221291
222292 def _adjust_output_path (self , output_file : str ) -> str :
223293 """Adjust the file path to reflect that it is a partition."""
0 commit comments