1313# limitations under the License.
1414"""Utilities to sample csv files.
1515
16- To sample a CSV data file, run the command:
16+ To sample 100 rows from a CSV data file, run the command:
1717 python data_sampler.py --sampler_input=<input-csv> --sampler_output=<output-csv>
1818
19+ To sample all unique values from a set of columns, run the command
20+ with the additional options:
21+ --sampler_uniques_per_column=-1
22+ --sampler_output_rows=-1
23+
This generates a sample output CSV with at most 100 rows, selecting input rows
with unique column values.
2126
# Command-line flags controlling how rows are sampled from the input CSV.
flags.DEFINE_integer(
    'sampler_rows_per_key', 5,
    'The maximum number of rows to select for each unique value found.')
flags.DEFINE_integer(
    'sampler_uniques_per_column', 10,
    'The maximum number of unique values to track per column. '
    'If 0 or -1, all unique values are tracked.')
flags.DEFINE_float(
    'sampler_rate', -1,
    'The sampling rate for random row selection (e.g., 0.1 for 10%).')
flags.DEFINE_string(
    'sampler_unique_columns', '',
    'A comma-separated list of column names to use for selecting unique rows.')
flags.DEFINE_list(
    'sampler_column_keys', [],
    'A list of "column:file" pairs containing values that MUST be included '
    'in the sample if they appear in the input data. '
    # Bug fix: a trailing space was missing here, so the rendered help text
    # read "...sampler_uniques_per_column.Example:".
    'If empty (default), sampling is based on sampler_uniques_per_column. '
    'Example: "variableMeasured:prominent_svs.txt"')
flags.DEFINE_string('sampler_input_delimiter', ',',
                    'The delimiter used in the input CSV file.')
7085flags .DEFINE_string ('sampler_input_encoding' , 'UTF8' ,
7590_FLAGS = flags .FLAGS
7691
7792import file_util
93+ import mcf_file_util
7894
7995from config_map import ConfigMap
8096from counters import Counters
@@ -103,27 +119,35 @@ def __init__(
103119 self ,
104120 config_dict : dict = None ,
105121 counters : Counters = None ,
122+ column_include_values : dict = None ,
106123 ):
107124 """Initializes the DataSampler object.
108125
109126 Args:
110127 config_dict: A dictionary of configuration parameters.
111128 counters: A Counters object for tracking statistics.
129+ column_include_values: a dictionary of column-name to set of values
130+ in the column to be included in the sample
112131 """
113132 self ._config = ConfigMap (config_dict = get_default_config ())
114133 if config_dict :
115134 self ._config .add_configs (config_dict )
116135 self ._counters = counters if counters is not None else Counters ()
136+ self ._column_include_values = column_include_values
117137 self .reset ()
118138
119139 def reset (self ) -> None :
120140 """Resets the state of the DataSampler.
121141
122142 This method resets the internal state of the DataSampler, including the
123- counts of unique column values and the number of selected rows. This is
124- useful when you want to reuse the same DataSampler instance for sampling
125- multiple files.
143+ counts of unique column values and the number of selected rows.
126144 """
145+ if self ._config .get ('sampler_exhaustive' ):
146+ # Exhaustive mode overrides limits to capture all unique values.
147+ self ._config .set_config ('sampler_output_rows' , - 1 )
148+ self ._config .set_config ('sampler_uniques_per_column' , - 1 )
149+ self ._config .set_config ('sampler_rows_per_key' , 1 )
150+
127151 # Dictionary of unique values: count per column
128152 self ._column_counts = {}
129153 # Dictionary of column index: list of header strings
@@ -144,6 +168,16 @@ def reset(self) -> None:
144168 if col .strip ()
145169 ]
146170
171+ # Must include values: dict of column_name -> set of values
172+ self ._must_include_values = load_column_keys (
173+ self ._config .get ('sampler_column_keys' , []))
174+ if self ._column_include_values :
175+ for col , vals in self ._column_include_values .items ():
176+ self ._must_include_values .setdefault (col , set ()).update (vals )
177+
178+ # Map of column index -> set of values
179+ self ._must_include_indices = {}
180+
147181 def __del__ (self ) -> None :
148182 """Logs the column headers and counts upon object deletion."""
149183 logging .log (2 , f'Sampler column headers: { self ._column_headers } ' )
@@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int:
179213 return 0
180214 return col_values .get (value , 0 )
181215
182- def _should_track_column (self , column_index : int ) -> bool :
183- """Determines if a column should be tracked for unique values .
216+ def _is_unique_column (self , column_index : int ) -> bool :
217+ """Determines if a column is specified for unique value sampling .
184218
185219 Args:
186220 column_index: The index of the column.
187221
188222 Returns:
189- True if the column should be tracked (either no unique columns
190- specified or this column is in the unique columns list).
223+ True if the column should be sampled for unique values.
191224 """
192225 if not self ._unique_column_names :
193- # No specific columns specified, track all
226+ # No specific columns specified, track all for unique sampling
194227 return True
195228 # Check if this column is in our unique columns
196229 return column_index in self ._unique_column_indices .values ()
197230
231+ def _should_track_column (self , column_index : int ) -> bool :
232+ """Determines if a column should be tracked for counts.
233+
234+ Args:
235+ column_index: The index of the column.
236+
237+ Returns:
238+ True if the column should be tracked for unique values or is a
239+ must-include column.
240+ """
241+ if self ._is_unique_column (column_index ):
242+ return True
243+ return column_index in self ._must_include_indices
244+
def _process_header_row(self, row: list[str]) -> None:
    """Builds column-name to index mappings from a header row.

    Populates self._unique_column_indices for configured unique columns and
    self._must_include_indices for columns with must-include value sets.

    Args:
      row: A header row containing column names.
    """
    for index, column_name in enumerate(row):
        # Map configured unique-sampling columns to their positions.
        if self._unique_column_names and column_name in self._unique_column_names:
            self._unique_column_indices[column_name] = index
            logging.level_debug() and logging.debug(
                f'Mapped unique column "{column_name}" to index {index}')

        # Map columns that carry must-include values to their positions.
        if self._must_include_values and column_name in self._must_include_values:
            self._must_include_indices[index] = self._must_include_values[
                column_name]
            logging.info(
                f'Mapped must-include column "{column_name}" to index {index}')
268+
def _is_must_include(self, column_index: int, value: str) -> bool:
    """Returns True if the (namespace-stripped) value is a must-include
    value for the column at column_index."""
    required_values = self._must_include_indices.get(column_index)
    if required_values is None:
        return False
    # Strip any namespace prefix before matching against the loaded keys.
    return mcf_file_util.strip_namespace(value) in required_values
276+
218277 def _add_column_header (self , column_index : int , value : str ) -> str :
219278 """Adds the first non-empty value of a column as its header.
220279
@@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
282341 # Too many rows already selected. Drop it.
283342 return False
284343 max_count = self ._config .get ('sampler_rows_per_key' , 3 )
344+ if max_count <= 0 :
345+ max_count = sys .maxsize
285346 max_uniques_per_col = self ._config .get ('sampler_uniques_per_column' , 10 )
347+ if max_uniques_per_col <= 0 :
348+ max_uniques_per_col = sys .maxsize
349+
286350 for index in range (len (row )):
287- # Skip columns not in unique_columns list
288- if not self ._should_track_column (index ):
289- continue
290351 value = row [index ]
291352 value_count = self ._get_column_count (index , value )
353+
354+ # Rule 1: Always include if it's a must-include value and
355+ # we haven't reached per-key limit.
356+ if value_count < max_count and self ._is_must_include (index , value ):
357+ self ._counters .add_counter ('sampler-selected-must-include' , 1 )
358+ return True
359+
360+ # Skip columns not in unique_columns list for general unique sampling
361+ if not self ._is_unique_column (index ):
362+ continue
363+
292364 if value_count == 0 or value_count < max_count :
293365 # This is a new value for this column.
294366 col_counts = self ._column_counts .get (index , {})
@@ -301,7 +373,7 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
301373 # No new unique value for the row.
302374 # Check random sampler.
303375 if sample_rate < 0 :
304- sample_rate = self ._config .get ('sampler_rate' )
376+ sample_rate = self ._config .get ('sampler_rate' , - 1 )
305377 if random .random () <= sample_rate :
306378 self ._counters .add_counter ('sampler-sampled-rows' , 1 )
307379 return True
@@ -390,6 +462,7 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
390462 # Process and write header rows from the first input file.
391463 if row_index <= header_rows and input_index == 0 :
392464 self ._process_header_row (row )
465+ self ._add_row_counts (row )
393466 csv_writer .writerow (row )
394467 self ._counters .add_counter ('sampler-header-rows' , 1 )
395468 # After processing all header rows, validate that all
@@ -425,9 +498,41 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
425498 return output_file
426499
427500
def load_column_keys(column_keys: list) -> dict:
    """Returns a dictionary of column name to a set of keys loaded from files.

    The set of keys for a column is used as a filter when sampling.

    Args:
      column_keys: list (or comma separated string) of '<column-name>:<file>'
        entries, where the first column of each CSV file holds the keys
        to be loaded for that column.

    Returns:
      dictionary of column name to a set of namespace-stripped keys:
      { <column-name1>: { key1, key2, ...}, <column-name2>: { k1, k2...} ...}
      Entries without a file name (including entries missing the ':'
      separator) are logged and skipped.
    """
    column_map = {}
    if not isinstance(column_keys, list):
        column_keys = column_keys.split(',')

    for col_file in column_keys:
        # partition() never raises: an entry without ':' yields an empty
        # file_name, unlike split(':', 1) which raised ValueError.
        column_name, _, file_name = col_file.partition(':')
        if not file_name:
            logging.error(f'No file for column {column_name} in {column_keys}')
            continue

        col_items = file_util.file_load_csv_dict(file_name)
        # Keys are namespace-stripped so lookups match normalized values.
        column_map[column_name] = {
            mcf_file_util.strip_namespace(val) for val in col_items.keys()
        }
        logging.info(
            f'Loaded {len(col_items)} for column {column_name} from {file_name}')
    return column_map
530+
531+
def sample_csv_file(input_file: str,
                    output_file: str = '',
                    config: dict = None,
                    counters: Counters = None) -> str:
    """Samples a CSV file and returns the path to the sampled file.

    Convenience wrapper that builds a DataSampler from the given config
    and runs it over the input file.

    Args:
      input_file: Path to the input CSV file to sample.
      output_file: Path for the sampled output CSV. If empty, a default
        output path is chosen by the sampler.
      config: Optional dictionary of sampler configuration parameters
        (e.g. sampler_output_rows, sampler_unique_columns, sampler_rate,
        input_delimiter, output_delimiter, input_encoding).
      counters: optional Counters object to get counts of sampling.

    Returns:
      The path to the output file with the sampled rows.
    """
    sampler_config = config if config is not None else {}
    data_sampler = DataSampler(config_dict=sampler_config, counters=counters)
    return data_sampler.sample_csv_file(input_file, output_file)
489595
490596
@@ -500,23 +606,32 @@ def get_default_config() -> dict:
500606 # Use default values of flags for tests
501607 if not _FLAGS .is_parsed ():
502608 _FLAGS .mark_as_parsed ()
503- return {
609+
610+ config = {
504611 'sampler_rate' : _FLAGS .sampler_rate ,
505612 'sampler_input' : _FLAGS .sampler_input ,
506613 'sampler_output' : _FLAGS .sampler_output ,
507614 'sampler_output_rows' : _FLAGS .sampler_output_rows ,
508615 'header_rows' : _FLAGS .sampler_header_rows ,
509616 'sampler_rows_per_key' : _FLAGS .sampler_rows_per_key ,
617+ 'sampler_uniques_per_column' : _FLAGS .sampler_uniques_per_column ,
510618 'sampler_column_regex' : _FLAGS .sampler_column_regex ,
511619 'sampler_unique_columns' : _FLAGS .sampler_unique_columns ,
620+ 'sampler_column_keys' : _FLAGS .sampler_column_keys ,
512621 'input_delimiter' : _FLAGS .sampler_input_delimiter ,
513622 'output_delimiter' : _FLAGS .sampler_output_delimiter ,
514623 'input_encoding' : _FLAGS .sampler_input_encoding ,
515624 }
516625
626+ return config
627+
517628
def main(_):
    """Samples the flag-specified input CSV and prints sampling counters."""
    run_counters = Counters()
    sample_csv_file(_FLAGS.sampler_input,
                    _FLAGS.sampler_output,
                    counters=run_counters)
    run_counters.print_counters()
520635
521636
522637if __name__ == '__main__' :
0 commit comments