From 47d5b0862a69bfd21740ab32247f251a9b509452 Mon Sep 17 00:00:00 2001 From: Aireen Mei Date: Wed, 3 Jun 2026 21:37:00 -0700 Subject: [PATCH] Support non-sequence meta-features in PyGrain packing transformations. Previously, only sequence meta-features (defined in `length_struct`) were supported. Non-sequence meta-features (in `meta_features` but not in `length_struct`) were stripped or caused errors during packing. This change adds support for non-sequence meta-features in both FirstFit and BestFit packing methods: - Python implementation (`PackedBatch`): Non-sequence meta-features are identified, accumulated as lists during packing, and yielded as 1D numpy object arrays of lists. - Python Iterator (`PackingDatasetIterator`): The `_combined_struct` is updated to include non-sequence meta-features so they are not stripped from input elements. - Refactored common key-extraction logic into shared helper functions in `packing_packed_batch.py`. - Added unit tests in `testing_util.py` to verify FirstFit and BestFit with non-sequence meta-features (both fixed and variable shapes). - Updated docstrings in `packing.py` to document the behavior of sequence vs non-sequence meta-features. PiperOrigin-RevId: 926448116 --- CHANGELOG.md | 1 + .../_src/python/dataset/transformations/BUILD | 1 + .../python/dataset/transformations/packing.py | 26 ++++- .../transformations/packing_packed_batch.py | 73 ++++++++++-- .../dataset/transformations/testing_util.py | 109 +++++++++++++++++- 5 files changed, 198 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f68835e14..204d2ef16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change * Adds profiling of multiprocess workers when using XProf profiler. To enable, set flag `grain_enable_multiprocess_worker_profiling=true` and add `"profile_subprocesses" = True` in advanced profiler options. + * Adds support for non-sequence meta-features in packing transformations. * Breaking changes: diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index af3e71a19..fae3190d2 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -94,6 +94,7 @@ py_library( srcs_version = "PY3", deps = [ ":packing", + ":packing_packed_batch", "//grain/_src/core:tree_lib", "//grain/_src/python/dataset", "@abseil-py//absl/testing:absltest", diff --git a/grain/_src/python/dataset/transformations/packing.py b/grain/_src/python/dataset/transformations/packing.py index a71396112..29f7a18c8 100644 --- a/grain/_src/python/dataset/transformations/packing.py +++ b/grain/_src/python/dataset/transformations/packing.py @@ -14,6 +14,7 @@ """Implements packing transformations.""" from collections.abc import Sequence +import copy from typing import Any, Type from grain._src.core import tree_lib @@ -75,6 +76,11 @@ def __init__( self._pack_alignment_struct = pack_alignment_struct self._padding_struct = padding_struct self._max_sequences_per_bin = max_sequences_per_bin + self._combined_struct = copy.copy(length_struct) + if isinstance(self._combined_struct, dict): + for k in meta_features: + if k not in self._combined_struct: + self._combined_struct[k] = None self._reset() def _reset(self): @@ -217,7 +223,7 @@ def __next__(self): with timer: # Remove elements not in packing struct. element = tree_lib.map_structure_up_to( - self._length_struct, lambda x: x, element + self._combined_struct, lambda x: x, element ) if self._current_batch is None: @@ -287,7 +293,11 @@ def __init__( seed: Random seed for shuffling bins. shuffle_bins: Whether to shuffle bins after packing. shuffle_bins_group_by_feature: Feature to group by for shuffling. - meta_features: Meta features that do not need packing logic. + meta_features: Meta features that do not need packing logic. They can be + sequence meta-features (if present in `length_struct`, packed and padded + to target length) or non-sequence meta-features (if not present in + `length_struct`, returned as a list of meta-features from the packed + examples in each bin). pack_alignment_struct: Optional per-feature alignment values. padding_struct: Optional per-feature padding values. max_sequences_per_bin: Optional maximum number of input sequences that can @@ -367,7 +377,11 @@ def __init__( seed: Random seed for shuffling bins. shuffle_bins: Whether to shuffle bins after packing. shuffle_bins_group_by_feature: Feature to group by for shuffling. - meta_features: Meta features that do not need packing logic. + meta_features: Meta features that do not need packing logic. They can be + sequence meta-features (if present in `length_struct`, packed and padded + to target length) or non-sequence meta-features (if not present in + `length_struct`, returned as a list of meta-features from the packed + examples in each bin). pack_alignment_struct: Optional per-feature alignment values. padding_struct: Optional per-feature padding values. max_sequences_per_bin: Optional maximum number of input sequences that can @@ -424,7 +438,11 @@ def __init__( seed: Random seed for shuffling bins. shuffle_bins: Whether to shuffle bins after packing. shuffle_bins_group_by_feature: Feature to group by for shuffling. - meta_features: Meta features that do not need packing logic. + meta_features: Meta features that do not need packing logic. They can be + sequence meta-features (if present in `length_struct`, packed and padded + to target length) or non-sequence meta-features (if not present in + `length_struct`, returned as a list of meta-features from the packed + examples in each bin). pack_alignment_struct: Optional per-feature alignment values. padding_struct: Optional per-feature padding values. max_sequences_per_bin: Optional maximum number of input sequences that can diff --git a/grain/_src/python/dataset/transformations/packing_packed_batch.py b/grain/_src/python/dataset/transformations/packing_packed_batch.py index 02271d35b..56f8ae1ee 100644 --- a/grain/_src/python/dataset/transformations/packing_packed_batch.py +++ b/grain/_src/python/dataset/transformations/packing_packed_batch.py @@ -47,6 +47,29 @@ _T = TypeVar("_T") +def get_key_name(path_element: Any) -> str: + if hasattr(path_element, "key"): + return path_element.key + if hasattr(path_element, "name"): + return path_element.name + return path_element + + +def get_length_struct_keys(length_struct: Any) -> set[str]: + return { + get_key_name(p[0]) + for (p, _) in tree_lib.flatten_with_path(length_struct) + if p + } + + +def get_non_sequence_meta_features( + length_struct: Any, meta_features: Sequence[str] +) -> list[str]: + keys = get_length_struct_keys(length_struct) + return [k for k in meta_features if k not in keys] + + @dataclasses.dataclass(frozen=True, kw_only=True) class SuccessfulRowOrFailingComponents: # Holds the index of the row to put a new element into if it can fit, @@ -118,6 +141,14 @@ def __init__( self._size_bytes = 0 self._max_sequences_per_bin = max_sequences_per_bin + self._non_sequence_meta_features = get_non_sequence_meta_features( + length_struct, meta_features + ) + self._non_sequence_meta_values = { + k: [[] for _ in range(num_packing_bins)] + for k in self._non_sequence_meta_features + } + # Define the main buffers we will pack the data into. def make_packed_buffer(length: int, x: np.ndarray | int, padding: Any): is_scalar = np.ndim(x) == 0 @@ -143,8 +174,15 @@ def make_packed_buffer(length: int, x: np.ndarray | int, padding: Any): if padding_struct is None: padding_struct = tree_lib.map_structure(lambda x: None, length_struct) + pruned_element_for_shapes = tree_lib.map_structure_up_to( + length_struct, lambda x: x, element_for_shapes + ) + self._values = tree_lib.map_structure( - make_packed_buffer, length_struct, element_for_shapes, padding_struct + make_packed_buffer, + length_struct, + pruned_element_for_shapes, + padding_struct, ) def make_packed_aux_info(length: int): @@ -219,8 +257,16 @@ def get_packed_batch(self): self._positions, ) + values_copy = copy.copy(values) + for k in self._non_sequence_meta_features: + sliced_list = self._non_sequence_meta_values[k][:rows_with_values] + arr = np.empty(len(sliced_list), dtype=object) + for i, val in enumerate(sliced_list): + arr[i] = val + values_copy[k] = arr + return _extract_and_rekey_packed_batch( - values, + values_copy, segment_ids=segment_ids, positions=positions, meta_features=self._meta_features, @@ -237,7 +283,10 @@ def _get_element_lengths_flat(self, element: Any) -> np.ndarray: def add_element_to_batch(self, element: Any, row: int) -> None: """Adds an element to the specified row using pre-flattened buffers.""" - flat_element = tree_lib.flatten(element) + packable_element = tree_lib.map_structure_up_to( + self._length_struct, lambda x: x, element + ) + flat_element = tree_lib.flatten(packable_element) flat_alignments = tree_lib.flatten(self._pack_alignments) segment_id = self._num_examples_per_row[row] + 1 @@ -257,6 +306,10 @@ def add_element_to_batch(self, element: Any, row: int) -> None: self._flat_first_free_cell_per_row[idx][row] = padded_end self._num_examples_per_row[row] += 1 + for k in self._non_sequence_meta_features: + if k in element: + self._non_sequence_meta_values[k][row].append(element[k]) + @abc.abstractmethod def try_add_to_batch(self, element: Any) -> list[str] | None: """Tries to add an element to the batch using a specific strategy.""" @@ -267,8 +320,11 @@ class FirstFitPackedBatch(PackedBatch[_T]): """Implements first-fit packing of sequences.""" def try_add_to_batch(self, element: Any) -> list[str] | None: - tree_lib.assert_same_structure(element, self._length_struct) - element_lengths = self._get_element_lengths_flat(element) + packable_element = tree_lib.map_structure_up_to( + self._length_struct, lambda x: x, element + ) + tree_lib.assert_same_structure(packable_element, self._length_struct) + element_lengths = self._get_element_lengths_flat(packable_element) # Check if any feature exceeds its max length before attempting to pack. too_long = element_lengths > self._capacities @@ -336,8 +392,11 @@ class BestFitPackedBatch(PackedBatch[_T]): """Implements best-fit packing of sequences.""" def try_add_to_batch(self, element: Any) -> list[str] | None: - tree_lib.assert_same_structure(element, self._length_struct) - element_lengths = self._get_element_lengths_flat(element) + packable_element = tree_lib.map_structure_up_to( + self._length_struct, lambda x: x, element + ) + tree_lib.assert_same_structure(packable_element, self._length_struct) + element_lengths = self._get_element_lengths_flat(packable_element) # Check if any feature exceeds its max length before attempting to pack. too_long = element_lengths > self._capacities diff --git a/grain/_src/python/dataset/transformations/testing_util.py b/grain/_src/python/dataset/transformations/testing_util.py index 0c55cd0f1..23c5a9e8e 100644 --- a/grain/_src/python/dataset/transformations/testing_util.py +++ b/grain/_src/python/dataset/transformations/testing_util.py @@ -9,6 +9,7 @@ from absl.testing import parameterized from grain._src.core import tree_lib from grain._src.python.dataset.transformations import packing +from grain._src.python.dataset.transformations import packing_packed_batch from grain._src.python.dataset.transformations import source from jax import numpy as jnp import numpy as np @@ -50,8 +51,18 @@ def _common_test_body( input_elements = [ {k: np.asarray(v) for k, v in d.items()} for d in input_elements ] + non_sequence_meta_features = ( + packing_packed_batch.get_non_sequence_meta_features( + length_struct, meta_features + ) + ) + expected_elements = [ - {k: np.asarray(v) for k, v in d.items()} for d in expected_elements + { + k: v if k in non_sequence_meta_features else np.asarray(v) + for k, v in d.items() + } + for d in expected_elements ] ds = packer_cls( source.SourceMapDataset(input_elements).to_iter_dataset(), @@ -1487,6 +1498,102 @@ def test_pack_sequences_with_zeros(self): max_sequences_per_bin=3, ) + @parameterized.product( + convert_input_to_np=[True, False], + ) + def test_non_sequence_meta_feature(self, convert_input_to_np: bool): + input_elements = [ + { + "inputs": np.asarray([1]), + "targets": np.asarray([10]), + "image": np.ones((2, 2, 3), dtype=np.int32) * 1, + }, + { + "inputs": np.asarray([2]), + "targets": np.asarray([20]), + "image": np.ones((2, 2, 3), dtype=np.int32) * 2, + }, + { + "inputs": np.asarray([3]), + "targets": np.asarray([30]), + "image": np.ones((2, 2, 3), dtype=np.int32) * 3, + }, + ] + length_struct = {"inputs": 3, "targets": 3} + + expected_elements = [{ + "inputs": [1, 2, 3], + "targets": [10, 20, 30], + "inputs_segment_ids": [1, 2, 3], + "targets_segment_ids": [1, 2, 3], + "inputs_positions": [0, 0, 0], + "targets_positions": [0, 0, 0], + "image": [ + np.ones((2, 2, 3), dtype=np.int32) * 1, + np.ones((2, 2, 3), dtype=np.int32) * 2, + np.ones((2, 2, 3), dtype=np.int32) * 3, + ], + }] + _common_test_body( + self.packer_cls, + input_elements, + expected_elements, + length_struct, + kwargs=self.kwargs, + num_packing_bins=1, + meta_features=["image"], + convert_input_to_np=convert_input_to_np, + ) + + @parameterized.product( + convert_input_to_np=[True, False], + ) + def test_non_sequence_meta_feature_variable_shapes( + self, convert_input_to_np: bool + ): + input_elements = [ + { + "inputs": np.asarray([1]), + "targets": np.asarray([10]), + "image": np.ones((2, 1, 3, 3), dtype=np.int32) * 1, + }, + { + "inputs": np.asarray([2]), + "targets": np.asarray([20]), + "image": np.ones((1, 1, 3, 3), dtype=np.int32) * 2, + }, + { + "inputs": np.asarray([3]), + "targets": np.asarray([30]), + "image": np.ones((3, 1, 3, 3), dtype=np.int32) * 3, + }, + ] + length_struct = {"inputs": 3, "targets": 3} + + expected_elements = [{ + "inputs": [1, 2, 3], + "targets": [10, 20, 30], + "inputs_segment_ids": [1, 2, 3], + "targets_segment_ids": [1, 2, 3], + "inputs_positions": [0, 0, 0], + "targets_positions": [0, 0, 0], + "image": [ + np.ones((2, 1, 3, 3), dtype=np.int32) * 1, + np.ones((1, 1, 3, 3), dtype=np.int32) * 2, + np.ones((3, 1, 3, 3), dtype=np.int32) * 3, + ], + }] + _common_test_body( + self.packer_cls, + input_elements, + expected_elements, + length_struct, + kwargs=self.kwargs, + num_packing_bins=1, + meta_features=["image"], + convert_input_to_np=convert_input_to_np, + ) + class BaseBestFitPackIterDatasetTest(BaseFirstFitPackIterDatasetTest): """Base test for the Best-Fit packing algorithm.