Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
* Adds profiling of multiprocess workers when using XProf profiler. To enable,
set flag `grain_enable_multiprocess_worker_profiling=true` and add
`"profile_subprocesses" = True` in advanced profiler options.
* Adds support for non-sequence meta-features in packing transformations.

* Breaking changes:

Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ py_library(
srcs_version = "PY3",
deps = [
":packing",
":packing_packed_batch",
"//grain/_src/core:tree_lib",
"//grain/_src/python/dataset",
"@abseil-py//absl/testing:absltest",
Expand Down
26 changes: 22 additions & 4 deletions grain/_src/python/dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Implements packing transformations."""

from collections.abc import Sequence
import copy
from typing import Any, Type

from grain._src.core import tree_lib
Expand Down Expand Up @@ -75,6 +76,11 @@ def __init__(
self._pack_alignment_struct = pack_alignment_struct
self._padding_struct = padding_struct
self._max_sequences_per_bin = max_sequences_per_bin
self._combined_struct = copy.copy(length_struct)
if isinstance(self._combined_struct, dict):
for k in meta_features:
if k not in self._combined_struct:
self._combined_struct[k] = None
self._reset()

def _reset(self):
Expand Down Expand Up @@ -217,7 +223,7 @@ def __next__(self):
with timer:
# Remove elements not in packing struct.
element = tree_lib.map_structure_up_to(
self._length_struct, lambda x: x, element
self._combined_struct, lambda x: x, element
)

if self._current_batch is None:
Expand Down Expand Up @@ -287,7 +293,11 @@ def __init__(
seed: Random seed for shuffling bins.
shuffle_bins: Whether to shuffle bins after packing.
shuffle_bins_group_by_feature: Feature to group by for shuffling.
meta_features: Meta features that do not need packing logic.
meta_features: Meta features that do not need packing logic. They can be
sequence meta-features (if present in `length_struct`, packed and padded
to target length) or non-sequence meta-features (if not present in
`length_struct`, returned as a list of meta-features from the packed
examples in each bin).
pack_alignment_struct: Optional per-feature alignment values.
padding_struct: Optional per-feature padding values.
max_sequences_per_bin: Optional maximum number of input sequences that can
Expand Down Expand Up @@ -367,7 +377,11 @@ def __init__(
seed: Random seed for shuffling bins.
shuffle_bins: Whether to shuffle bins after packing.
shuffle_bins_group_by_feature: Feature to group by for shuffling.
meta_features: Meta features that do not need packing logic.
meta_features: Meta features that do not need packing logic. They can be
sequence meta-features (if present in `length_struct`, packed and padded
to target length) or non-sequence meta-features (if not present in
`length_struct`, returned as a list of meta-features from the packed
examples in each bin).
pack_alignment_struct: Optional per-feature alignment values.
padding_struct: Optional per-feature padding values.
max_sequences_per_bin: Optional maximum number of input sequences that can
Expand Down Expand Up @@ -424,7 +438,11 @@ def __init__(
seed: Random seed for shuffling bins.
shuffle_bins: Whether to shuffle bins after packing.
shuffle_bins_group_by_feature: Feature to group by for shuffling.
meta_features: Meta features that do not need packing logic.
meta_features: Meta features that do not need packing logic. They can be
sequence meta-features (if present in `length_struct`, packed and padded
to target length) or non-sequence meta-features (if not present in
`length_struct`, returned as a list of meta-features from the packed
examples in each bin).
pack_alignment_struct: Optional per-feature alignment values.
padding_struct: Optional per-feature padding values.
max_sequences_per_bin: Optional maximum number of input sequences that can
Expand Down
73 changes: 66 additions & 7 deletions grain/_src/python/dataset/transformations/packing_packed_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,29 @@
_T = TypeVar("_T")


def get_key_name(path_element: Any) -> str:
if hasattr(path_element, "key"):
return path_element.key
if hasattr(path_element, "name"):
return path_element.name
return path_element


def get_length_struct_keys(length_struct: Any) -> set[str]:
return {
get_key_name(p[0])
for (p, _) in tree_lib.flatten_with_path(length_struct)
if p
}


def get_non_sequence_meta_features(
length_struct: Any, meta_features: Sequence[str]
) -> list[str]:
keys = get_length_struct_keys(length_struct)
return [k for k in meta_features if k not in keys]


@dataclasses.dataclass(frozen=True, kw_only=True)
class SuccessfulRowOrFailingComponents:
# Holds the index of the row to put a new element into if it can fit,
Expand Down Expand Up @@ -118,6 +141,14 @@ def __init__(
self._size_bytes = 0
self._max_sequences_per_bin = max_sequences_per_bin

self._non_sequence_meta_features = get_non_sequence_meta_features(
length_struct, meta_features
)
self._non_sequence_meta_values = {
k: [[] for _ in range(num_packing_bins)]
for k in self._non_sequence_meta_features
}

# Define the main buffers we will pack the data into.
def make_packed_buffer(length: int, x: np.ndarray | int, padding: Any):
is_scalar = np.ndim(x) == 0
Expand All @@ -143,8 +174,15 @@ def make_packed_buffer(length: int, x: np.ndarray | int, padding: Any):
if padding_struct is None:
padding_struct = tree_lib.map_structure(lambda x: None, length_struct)

pruned_element_for_shapes = tree_lib.map_structure_up_to(
length_struct, lambda x: x, element_for_shapes
)

self._values = tree_lib.map_structure(
make_packed_buffer, length_struct, element_for_shapes, padding_struct
make_packed_buffer,
length_struct,
pruned_element_for_shapes,
padding_struct,
)

def make_packed_aux_info(length: int):
Expand Down Expand Up @@ -219,8 +257,16 @@ def get_packed_batch(self):
self._positions,
)

values_copy = copy.copy(values)
for k in self._non_sequence_meta_features:
sliced_list = self._non_sequence_meta_values[k][:rows_with_values]
arr = np.empty(len(sliced_list), dtype=object)
for i, val in enumerate(sliced_list):
arr[i] = val
values_copy[k] = arr

return _extract_and_rekey_packed_batch(
values,
values_copy,
segment_ids=segment_ids,
positions=positions,
meta_features=self._meta_features,
Expand All @@ -237,7 +283,10 @@ def _get_element_lengths_flat(self, element: Any) -> np.ndarray:

def add_element_to_batch(self, element: Any, row: int) -> None:
"""Adds an element to the specified row using pre-flattened buffers."""
flat_element = tree_lib.flatten(element)
packable_element = tree_lib.map_structure_up_to(
self._length_struct, lambda x: x, element
)
flat_element = tree_lib.flatten(packable_element)
flat_alignments = tree_lib.flatten(self._pack_alignments)
segment_id = self._num_examples_per_row[row] + 1

Expand All @@ -257,6 +306,10 @@ def add_element_to_batch(self, element: Any, row: int) -> None:
self._flat_first_free_cell_per_row[idx][row] = padded_end
self._num_examples_per_row[row] += 1

for k in self._non_sequence_meta_features:
if k in element:
self._non_sequence_meta_values[k][row].append(element[k])

@abc.abstractmethod
def try_add_to_batch(self, element: Any) -> list[str] | None:
"""Tries to add an element to the batch using a specific strategy."""
Expand All @@ -267,8 +320,11 @@ class FirstFitPackedBatch(PackedBatch[_T]):
"""Implements first-fit packing of sequences."""

def try_add_to_batch(self, element: Any) -> list[str] | None:
tree_lib.assert_same_structure(element, self._length_struct)
element_lengths = self._get_element_lengths_flat(element)
packable_element = tree_lib.map_structure_up_to(
self._length_struct, lambda x: x, element
)
tree_lib.assert_same_structure(packable_element, self._length_struct)
element_lengths = self._get_element_lengths_flat(packable_element)

# Check if any feature exceeds its max length before attempting to pack.
too_long = element_lengths > self._capacities
Expand Down Expand Up @@ -336,8 +392,11 @@ class BestFitPackedBatch(PackedBatch[_T]):
"""Implements best-fit packing of sequences."""

def try_add_to_batch(self, element: Any) -> list[str] | None:
tree_lib.assert_same_structure(element, self._length_struct)
element_lengths = self._get_element_lengths_flat(element)
packable_element = tree_lib.map_structure_up_to(
self._length_struct, lambda x: x, element
)
tree_lib.assert_same_structure(packable_element, self._length_struct)
element_lengths = self._get_element_lengths_flat(packable_element)

# Check if any feature exceeds its max length before attempting to pack.
too_long = element_lengths > self._capacities
Expand Down
109 changes: 108 additions & 1 deletion grain/_src/python/dataset/transformations/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from absl.testing import parameterized
from grain._src.core import tree_lib
from grain._src.python.dataset.transformations import packing
from grain._src.python.dataset.transformations import packing_packed_batch
from grain._src.python.dataset.transformations import source
from jax import numpy as jnp
import numpy as np
Expand Down Expand Up @@ -50,8 +51,18 @@ def _common_test_body(
input_elements = [
{k: np.asarray(v) for k, v in d.items()} for d in input_elements
]
non_sequence_meta_features = (
packing_packed_batch.get_non_sequence_meta_features(
length_struct, meta_features
)
)

expected_elements = [
{k: np.asarray(v) for k, v in d.items()} for d in expected_elements
{
k: v if k in non_sequence_meta_features else np.asarray(v)
for k, v in d.items()
}
for d in expected_elements
]
ds = packer_cls(
source.SourceMapDataset(input_elements).to_iter_dataset(),
Expand Down Expand Up @@ -1487,6 +1498,102 @@ def test_pack_sequences_with_zeros(self):
max_sequences_per_bin=3,
)

@parameterized.product(
convert_input_to_np=[True, False],
)
def test_non_sequence_meta_feature(self, convert_input_to_np: bool):
input_elements = [
{
"inputs": np.asarray([1]),
"targets": np.asarray([10]),
"image": np.ones((2, 2, 3), dtype=np.int32) * 1,
},
{
"inputs": np.asarray([2]),
"targets": np.asarray([20]),
"image": np.ones((2, 2, 3), dtype=np.int32) * 2,
},
{
"inputs": np.asarray([3]),
"targets": np.asarray([30]),
"image": np.ones((2, 2, 3), dtype=np.int32) * 3,
},
]
length_struct = {"inputs": 3, "targets": 3}

expected_elements = [{
"inputs": [1, 2, 3],
"targets": [10, 20, 30],
"inputs_segment_ids": [1, 2, 3],
"targets_segment_ids": [1, 2, 3],
"inputs_positions": [0, 0, 0],
"targets_positions": [0, 0, 0],
"image": [
np.ones((2, 2, 3), dtype=np.int32) * 1,
np.ones((2, 2, 3), dtype=np.int32) * 2,
np.ones((2, 2, 3), dtype=np.int32) * 3,
],
}]
_common_test_body(
self.packer_cls,
input_elements,
expected_elements,
length_struct,
kwargs=self.kwargs,
num_packing_bins=1,
meta_features=["image"],
convert_input_to_np=convert_input_to_np,
)

@parameterized.product(
convert_input_to_np=[True, False],
)
def test_non_sequence_meta_feature_variable_shapes(
self, convert_input_to_np: bool
):
input_elements = [
{
"inputs": np.asarray([1]),
"targets": np.asarray([10]),
"image": np.ones((2, 1, 3, 3), dtype=np.int32) * 1,
},
{
"inputs": np.asarray([2]),
"targets": np.asarray([20]),
"image": np.ones((1, 1, 3, 3), dtype=np.int32) * 2,
},
{
"inputs": np.asarray([3]),
"targets": np.asarray([30]),
"image": np.ones((3, 1, 3, 3), dtype=np.int32) * 3,
},
]
length_struct = {"inputs": 3, "targets": 3}

expected_elements = [{
"inputs": [1, 2, 3],
"targets": [10, 20, 30],
"inputs_segment_ids": [1, 2, 3],
"targets_segment_ids": [1, 2, 3],
"inputs_positions": [0, 0, 0],
"targets_positions": [0, 0, 0],
"image": [
np.ones((2, 1, 3, 3), dtype=np.int32) * 1,
np.ones((1, 1, 3, 3), dtype=np.int32) * 2,
np.ones((3, 1, 3, 3), dtype=np.int32) * 3,
],
}]
_common_test_body(
self.packer_cls,
input_elements,
expected_elements,
length_struct,
kwargs=self.kwargs,
num_packing_bins=1,
meta_features=["image"],
convert_input_to_np=convert_input_to_np,
)


class BaseBestFitPackIterDatasetTest(BaseFirstFitPackIterDatasetTest):
"""Base test for the Best-Fit packing algorithm.
Expand Down
Loading