Skip to content

Commit fe01035

Browse files
authored
Enhancements to data sampler (#1916)
* Support 0 eval results * Enhancements to data sampler * gemini review comments * print counters * review comments * fix test
1 parent a26161e commit fe01035

2 files changed

Lines changed: 279 additions & 37 deletions

File tree

tools/statvar_importer/data_sampler.py

Lines changed: 136 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313
# limitations under the License.
1414
"""Utilities to sample csv files.
1515
16-
To sample a CSV data file, run the command:
16+
To sample 100 rows from a CSV data file, run the command:
1717
python data_sampler.py --sampler_input=<input-csv> --sampler_output=<output-csv>
1818
19+
To sample all unique values from a set of columns, run the command
20+
with the additional options:
21+
--sampler_uniques_per_column=-1
22+
--sampler_output_rows=-1
23+
1924
This generates a sample output CSV with atmost 100 rows selecting input rows
2025
with unique column values.
2126
@@ -54,6 +59,10 @@
5459
flags.DEFINE_integer(
5560
'sampler_rows_per_key', 5,
5661
'The maximum number of rows to select for each unique value found.')
62+
flags.DEFINE_integer(
63+
'sampler_uniques_per_column', 10,
64+
'The maximum number of unique values to track per column. '
65+
'If 0 or -1, all unique values are tracked.')
5766
flags.DEFINE_float(
5867
'sampler_rate', -1,
5968
'The sampling rate for random row selection (e.g., 0.1 for 10%).')
@@ -65,6 +74,12 @@
6574
flags.DEFINE_string(
6675
'sampler_unique_columns', '',
6776
'A comma-separated list of column names to use for selecting unique rows.')
77+
flags.DEFINE_list(
78+
'sampler_column_keys', [],
79+
'A list of "column:file" pairs containing values that MUST be included '
80+
'in the sample if they appear in the input data. '
81+
'If empty (default), sampling is based on sampler_uniques_per_column.'
82+
'Example: "variableMeasured:prominent_svs.txt"')
6883
flags.DEFINE_string('sampler_input_delimiter', ',',
6984
'The delimiter used in the input CSV file.')
7085
flags.DEFINE_string('sampler_input_encoding', 'UTF8',
@@ -75,6 +90,7 @@
7590
_FLAGS = flags.FLAGS
7691

7792
import file_util
93+
import mcf_file_util
7894

7995
from config_map import ConfigMap
8096
from counters import Counters
@@ -103,27 +119,35 @@ def __init__(
103119
self,
104120
config_dict: dict = None,
105121
counters: Counters = None,
122+
column_include_values: dict = None,
106123
):
107124
"""Initializes the DataSampler object.
108125
109126
Args:
110127
config_dict: A dictionary of configuration parameters.
111128
counters: A Counters object for tracking statistics.
129+
column_include_values: a dictionary of column-name to set of values
130+
in the column to be included in the sample
112131
"""
113132
self._config = ConfigMap(config_dict=get_default_config())
114133
if config_dict:
115134
self._config.add_configs(config_dict)
116135
self._counters = counters if counters is not None else Counters()
136+
self._column_include_values = column_include_values
117137
self.reset()
118138

119139
def reset(self) -> None:
120140
"""Resets the state of the DataSampler.
121141
122142
This method resets the internal state of the DataSampler, including the
123-
counts of unique column values and the number of selected rows. This is
124-
useful when you want to reuse the same DataSampler instance for sampling
125-
multiple files.
143+
counts of unique column values and the number of selected rows.
126144
"""
145+
if self._config.get('sampler_exhaustive'):
146+
# Exhaustive mode overrides limits to capture all unique values.
147+
self._config.set_config('sampler_output_rows', -1)
148+
self._config.set_config('sampler_uniques_per_column', -1)
149+
self._config.set_config('sampler_rows_per_key', 1)
150+
127151
# Dictionary of unique values: count per column
128152
self._column_counts = {}
129153
# Dictionary of column index: list of header strings
@@ -144,6 +168,16 @@ def reset(self) -> None:
144168
if col.strip()
145169
]
146170

171+
# Must include values: dict of column_name -> set of values
172+
self._must_include_values = load_column_keys(
173+
self._config.get('sampler_column_keys', []))
174+
if self._column_include_values:
175+
for col, vals in self._column_include_values.items():
176+
self._must_include_values.setdefault(col, set()).update(vals)
177+
178+
# Map of column index -> set of values
179+
self._must_include_indices = {}
180+
147181
def __del__(self) -> None:
148182
"""Logs the column headers and counts upon object deletion."""
149183
logging.log(2, f'Sampler column headers: {self._column_headers}')
@@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int:
179213
return 0
180214
return col_values.get(value, 0)
181215

182-
def _should_track_column(self, column_index: int) -> bool:
183-
"""Determines if a column should be tracked for unique values.
216+
def _is_unique_column(self, column_index: int) -> bool:
217+
"""Determines if a column is specified for unique value sampling.
184218
185219
Args:
186220
column_index: The index of the column.
187221
188222
Returns:
189-
True if the column should be tracked (either no unique columns
190-
specified or this column is in the unique columns list).
223+
True if the column should be sampled for unique values.
191224
"""
192225
if not self._unique_column_names:
193-
# No specific columns specified, track all
226+
# No specific columns specified, track all for unique sampling
194227
return True
195228
# Check if this column is in our unique columns
196229
return column_index in self._unique_column_indices.values()
197230

231+
def _should_track_column(self, column_index: int) -> bool:
232+
"""Determines if a column should be tracked for counts.
233+
234+
Args:
235+
column_index: The index of the column.
236+
237+
Returns:
238+
True if the column should be tracked for unique values or is a
239+
must-include column.
240+
"""
241+
if self._is_unique_column(column_index):
242+
return True
243+
return column_index in self._must_include_indices
244+
198245
def _process_header_row(self, row: list[str]) -> None:
199246
"""Process a header row to build column name to index mapping.
200247
@@ -206,15 +253,27 @@ def _process_header_row(self, row: list[str]) -> None:
206253
Args:
207254
row: A header row containing column names.
208255
"""
209-
if not self._unique_column_names:
210-
return
211-
212256
for index, column_name in enumerate(row):
213-
if column_name in self._unique_column_names:
257+
if self._unique_column_names and column_name in self._unique_column_names:
214258
self._unique_column_indices[column_name] = index
215259
logging.level_debug() and logging.debug(
216260
f'Mapped unique column "{column_name}" to index {index}')
217261

262+
if self._must_include_values and column_name in self._must_include_values:
263+
self._must_include_indices[index] = self._must_include_values[
264+
column_name]
265+
logging.info(
266+
f'Mapped must-include column "{column_name}" to index {index}'
267+
)
268+
269+
def _is_must_include(self, column_index: int, value: str) -> bool:
270+
"""Checks if a column value is in the must-include list."""
271+
if column_index not in self._must_include_indices:
272+
return False
273+
# Normalize the input value before checking against the set
274+
return mcf_file_util.strip_namespace(
275+
value) in self._must_include_indices[column_index]
276+
218277
def _add_column_header(self, column_index: int, value: str) -> str:
219278
"""Adds the first non-empty value of a column as its header.
220279
@@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
282341
# Too many rows already selected. Drop it.
283342
return False
284343
max_count = self._config.get('sampler_rows_per_key', 3)
344+
if max_count <= 0:
345+
max_count = sys.maxsize
285346
max_uniques_per_col = self._config.get('sampler_uniques_per_column', 10)
347+
if max_uniques_per_col <= 0:
348+
max_uniques_per_col = sys.maxsize
349+
286350
for index in range(len(row)):
287-
# Skip columns not in unique_columns list
288-
if not self._should_track_column(index):
289-
continue
290351
value = row[index]
291352
value_count = self._get_column_count(index, value)
353+
354+
# Rule 1: Always include if it's a must-include value and
355+
# we haven't reached per-key limit.
356+
if value_count < max_count and self._is_must_include(index, value):
357+
self._counters.add_counter('sampler-selected-must-include', 1)
358+
return True
359+
360+
# Skip columns not in unique_columns list for general unique sampling
361+
if not self._is_unique_column(index):
362+
continue
363+
292364
if value_count == 0 or value_count < max_count:
293365
# This is a new value for this column.
294366
col_counts = self._column_counts.get(index, {})
@@ -301,7 +373,7 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
301373
# No new unique value for the row.
302374
# Check random sampler.
303375
if sample_rate < 0:
304-
sample_rate = self._config.get('sampler_rate')
376+
sample_rate = self._config.get('sampler_rate', -1)
305377
if random.random() <= sample_rate:
306378
self._counters.add_counter('sampler-sampled-rows', 1)
307379
return True
@@ -390,6 +462,7 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
390462
# Process and write header rows from the first input file.
391463
if row_index <= header_rows and input_index == 0:
392464
self._process_header_row(row)
465+
self._add_row_counts(row)
393466
csv_writer.writerow(row)
394467
self._counters.add_counter('sampler-header-rows', 1)
395468
# After processing all header rows, validate that all
@@ -425,9 +498,41 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
425498
return output_file
426499

427500

501+
def load_column_keys(column_keys: list) -> dict:
502+
"""Returns a dictionary of column name to set of keys loaded from a file.
503+
The set of keys for a column are used as filter when sampling.
504+
505+
Args:
506+
column_keys: comma separated list of column_name:<csv file> with
507+
first column as the keys to be loaded.
508+
509+
Returns:
510+
dictionary of column name to a set of keys for that column
511+
{ <column-name1>: { key1, key2, ...}, <column-name2>: { k1, k2...} ...}
512+
"""
513+
column_map = {}
514+
if not isinstance(column_keys, list):
515+
column_keys = column_keys.split(',')
516+
517+
for col_file in column_keys:
518+
column_name, file_name = col_file.split(':', 1)
519+
if not file_name:
520+
logging.error(f'No file for column {column_name} in {column_keys}')
521+
continue
522+
523+
col_items = file_util.file_load_csv_dict(file_name)
524+
column_map[column_name] = set(
525+
{mcf_file_util.strip_namespace(val) for val in col_items.keys()})
526+
logging.info(
527+
f'Loaded {len(col_items)} for column {column_name} from {file_name}'
528+
)
529+
return column_map
530+
531+
428532
def sample_csv_file(input_file: str,
429533
output_file: str = '',
430-
config: dict = None) -> str:
534+
config: dict = None,
535+
counters: Counters = None) -> str:
431536
"""Samples a CSV file and returns the path to the sampled file.
432537
433538
This function provides a convenient way to sample a CSV file with a single
@@ -455,6 +560,7 @@ def sample_csv_file(input_file: str,
455560
- input_delimiter: The delimiter used in the input file.
456561
- output_delimiter: The delimiter to use in the output file.
457562
- input_encoding: The encoding of the input file.
563+
counters: optional Counters object to get counts of sampling.
458564
459565
Returns:
460566
The path to the output file with the sampled rows.
@@ -484,7 +590,7 @@ def sample_csv_file(input_file: str,
484590
"""
485591
if config is None:
486592
config = {}
487-
data_sampler = DataSampler(config_dict=config)
593+
data_sampler = DataSampler(config_dict=config, counters=counters)
488594
return data_sampler.sample_csv_file(input_file, output_file)
489595

490596

@@ -500,23 +606,32 @@ def get_default_config() -> dict:
500606
# Use default values of flags for tests
501607
if not _FLAGS.is_parsed():
502608
_FLAGS.mark_as_parsed()
503-
return {
609+
610+
config = {
504611
'sampler_rate': _FLAGS.sampler_rate,
505612
'sampler_input': _FLAGS.sampler_input,
506613
'sampler_output': _FLAGS.sampler_output,
507614
'sampler_output_rows': _FLAGS.sampler_output_rows,
508615
'header_rows': _FLAGS.sampler_header_rows,
509616
'sampler_rows_per_key': _FLAGS.sampler_rows_per_key,
617+
'sampler_uniques_per_column': _FLAGS.sampler_uniques_per_column,
510618
'sampler_column_regex': _FLAGS.sampler_column_regex,
511619
'sampler_unique_columns': _FLAGS.sampler_unique_columns,
620+
'sampler_column_keys': _FLAGS.sampler_column_keys,
512621
'input_delimiter': _FLAGS.sampler_input_delimiter,
513622
'output_delimiter': _FLAGS.sampler_output_delimiter,
514623
'input_encoding': _FLAGS.sampler_input_encoding,
515624
}
516625

626+
return config
627+
517628

518629
def main(_):
519-
sample_csv_file(_FLAGS.sampler_input, _FLAGS.sampler_output)
630+
counters = Counters()
631+
sample_csv_file(_FLAGS.sampler_input,
632+
_FLAGS.sampler_output,
633+
counters=counters)
634+
counters.print_counters()
520635

521636

522637
if __name__ == '__main__':

0 commit comments

Comments
 (0)