feat: Split program sets #1249
base: main
Changes from all commits
4e29c65
c55a6a2
cc80430
5a17aa7
0e2ad55
aba6722
746790d
cb55879
fe2cc3c
2774cdf
5495754
da10d72
dfc8920
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,8 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| - from collections.abc import Mapping, Sequence | ||
| + from collections.abc import Iterator, Mapping, Sequence | ||
| from dataclasses import dataclass | ||
|
|
||
| from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet | ||
|
|
||
|
|
@@ -97,6 +98,167 @@ def total_shots(self) -> int: | |
| raise ValueError("No per-executable shots defined") | ||
| return self._shots_per_executable * self.total_executables | ||
|
|
||
| def enumerate_executables(self) -> Iterator[tuple[int, int, int]]: | ||
| """Yield ``(binding_index, parameter_set_index, observable_index)`` tuples in order, | ||
| one per executable. | ||
|
|
||
| The iteration order is: iterate over ``self.entries``; within each entry, | ||
| iterate over parameter set indices; within each parameter set index, | ||
| iterate over observable indices. The total number of yields is ``self.total_executables``. | ||
|
|
||
| For ``Circuit``s and ``CircuitBinding``s with no input sets, ``parameter_set_index`` is 0. | ||
| For entries with no observables, ``observable_index`` is 0. For ``CircuitBinding``s with a | ||
| ``Sum`` Hamiltonian, ``observable_index`` ranges over the summands. | ||
|
|
||
| This ordering is used by ``split`` to build its index map and by | ||
| ``ProgramSetQuantumTaskResult.merge`` to merge results back into the original shape. | ||
|
|
||
| Yields: | ||
| tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``. | ||
| """ | ||
| for binding_idx, prog in enumerate(self._programs): | ||
| if isinstance(prog, Circuit): | ||
| yield binding_idx, 0, 0 | ||
| continue | ||
| num_obs = len(prog.observables) if prog.observables is not None else 1 | ||
| for ps_idx in range(len(prog.input_sets) if prog.input_sets is not None else 1): | ||
| for obs_idx in range(num_obs): | ||
| yield binding_idx, ps_idx, obs_idx | ||
|
|
||
| def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]: | ||
| """Split this program set into program sets of at most ``max_executables`` executables, | ||
| alongside a map that records the position in the original program set of each executable | ||
| in each of the generated program sets. | ||
|
Contributor
Would add, i.e., indexed from 0 to total_executables - 1.
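To make the indexing concrete, here is a minimal sketch of how flat executable indices 0 through total_executables - 1 line up with the enumeration order described in the `enumerate_executables` docstring. It does not use the Braket SDK: each entry is reduced to a hypothetical `(num_param_sets, num_observables)` pair, and `enumerate_shape` is an illustrative stand-in, not a function from this PR.

```python
from __future__ import annotations

from collections.abc import Iterator


def enumerate_shape(shapes: list[tuple[int, int]]) -> Iterator[tuple[int, int, int]]:
    """Yield (binding_index, parameter_set_index, observable_index) per executable.

    Each shape is a hypothetical (num_param_sets, num_observables) pair;
    a bare circuit is modelled as (1, 1).
    """
    for binding_idx, (num_ps, num_obs) in enumerate(shapes):
        for ps_idx in range(num_ps):
            for obs_idx in range(num_obs):
                yield binding_idx, ps_idx, obs_idx


# Two parameter sets x three observables, followed by one bare circuit.
shapes = [(2, 3), (1, 1)]
flat = list(enumerate_shape(shapes))
for flat_index, multi_index in enumerate(flat):
    # The flat index is simply the position in enumeration order.
    print(flat_index, multi_index)
```

Under this model the position of a tuple in `flat` is the value that would appear in `index_map`, so every index lies between 0 and `len(flat) - 1`.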
||
|
|
||
| When a single parameter set index of a ``CircuitBinding`` would by itself exceed | ||
| ``max_executables`` due to its observable list or ``Sum`` Hamiltonian being larger than | ||
| the budget, the observable list is split into chunks of at most ``max_executables`` entries | ||
| (``Sum`` summands are sliced with coefficients preserved). Observable splitting is only | ||
| performed when necessary; otherwise the full observable list or ``Sum`` is kept intact. | ||
|
Contributor
Does this have consequences for max splitting? What is the worst-case example here? For instance, a CB made up only of observable lists that are just too big for two to fit in one: 10 executables but only a 19 ps split? This approaches ~0.5 filling efficiency for large N. I think a flatten option might be nice, just so the user knows it has been split 1-1, though it adds complexity to the merge as well. Is there a simpler way to approach or think about this?
Contributor
Also, if it doesn't, that is okay; maybe we can just make it a bit clearer.
Contributor
It also seems to occur on the other side of the approach: 10 executables with a split limit of only 9 drives the filling efficiency toward 0.5.
Contributor
It also just occurred to me that we kind of have a flatten: bind_observables_to_inputs would do something like this, but there is no "unbind" or reorder. It would be cool if merge could take that as an input as well, though, as it should have the same ordering.
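The two cases raised in this thread can be checked numerically. The sketch below is a standalone model, not the PR's code. It assumes that every parameter set carries 10 observables, that the figures 19 and 9 are values of `max_executables`, and that blocks are packed greedily in order as in the `split` loop. The helper names are hypothetical.

```python
from __future__ import annotations


def observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]:
    """Chunk an observable list into (start, stop) windows of at most max_executables."""
    return [
        (start, min(start + max_executables, num_observables))
        for start in range(0, num_observables, max_executables)
    ]


def packed_sizes(block_sizes: list[int], max_executables: int) -> list[int]:
    """Greedily pack consecutive blocks; start a new set when the next block overflows."""
    sizes: list[int] = []
    current = 0
    for size in block_sizes:
        if current and current + size > max_executables:
            sizes.append(current)
            current = 0
        current += size
    if current:
        sizes.append(current)
    return sizes


def filling_efficiency(num_param_sets: int, num_observables: int, max_executables: int) -> float:
    """Fraction of the available executable slots that the split actually uses."""
    windows = observable_windows(num_observables, max_executables)
    block_sizes = [stop - start for _ in range(num_param_sets) for start, stop in windows]
    sizes = packed_sizes(block_sizes, max_executables)
    return sum(sizes) / (len(sizes) * max_executables)


print(filling_efficiency(1000, 10, 19))  # blocks of 10, budget 19
print(filling_efficiency(1000, 10, 9))   # blocks of 9 and 1 alternating, budget 9
```

Under these assumptions the efficiencies come out as 10/19 and 5/9, both slightly above one half, which is consistent with the ~0.5 figure quoted above.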
||
|
|
||
| Args: | ||
| max_executables (int): The maximum number of executables per program | ||
| set. Must be positive. | ||
|
|
||
| Returns: | ||
| tuple[list[ProgramSet], list[list[int]]]: ``(program_sets, index_map)``. | ||
| ``index_map[k][j]`` is the index of the executable that the j-th executable of | ||
| ``program_sets[k]`` represents. | ||
| If this program set already fits within ``max_executables``, the returned | ||
| program-set list is ``[self]`` and the index_map is ``[[0, 1, ..., | ||
| total_executables - 1]]``. | ||
|
|
||
| Raises: | ||
| ValueError: If ``max_executables`` is not positive. | ||
|
|
||
| Examples: | ||
| >>> ps = ProgramSet([ | ||
| ... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables | ||
| ... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables | ||
| ... ]) | ||
| >>> subs, index_map = ps.split(120) | ||
| >>> [s.total_executables for s in subs] | ||
| [120, 120, 120, 120, 20] | ||
| >>> sum(len(m) for m in index_map) == ps.total_executables | ||
| True | ||
| """ | ||
| if max_executables <= 0: | ||
| raise ValueError(f"max_executables must be positive, got {max_executables}") | ||
|
|
||
| if self.total_executables <= max_executables: | ||
| return [self], [list(range(self.total_executables))] | ||
|
|
||
| program_sets = [] | ||
| index_map = [] | ||
| current = [] | ||
| current_size = 0 | ||
| for block in self._executable_blocks(max_executables): | ||
| if current and current_size + block.size > max_executables: | ||
|
Contributor
This is basically the majority of the logic, no? If A + next can fit, use it. Else, continue.
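As a rough illustration of that greedy rule, the sketch below packs a list of block sizes and records which flat original indices land in each group. It is a simplified model with a hypothetical `greedy_split` helper: it works on plain integers rather than `ProgramSet` objects and assumes no single block exceeds the budget.

```python
from __future__ import annotations


def greedy_split(
    block_sizes: list[int], max_executables: int
) -> tuple[list[int], list[list[int]]]:
    """Pack consecutive blocks into groups of at most max_executables executables.

    Returns the size of each group and, per group, the flat original indices it covers.
    """
    sizes: list[int] = []
    index_map: list[list[int]] = []
    current: list[int] = []
    next_index = 0
    for size in block_sizes:
        # Close the current group if the next block would not fit.
        if current and len(current) + size > max_executables:
            sizes.append(len(current))
            index_map.append(current)
            current = []
        current.extend(range(next_index, next_index + size))
        next_index += size
    if current:
        sizes.append(len(current))
        index_map.append(current)
    return sizes, index_map


# Docstring example: 100 parameter sets x 4 observables, then 50 x 2.
block_sizes = [4] * 100 + [2] * 50
sizes, index_map = greedy_split(block_sizes, 120)
print(sizes)
```

With these inputs the group sizes match the `[120, 120, 120, 120, 20]` shown in the `split` docstring, and concatenating `index_map` gives every index from 0 to 499 exactly once.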
||
| sub, sub_map = self._build_program_set(current) | ||
| program_sets.append(sub) | ||
| index_map.append(sub_map) | ||
| current = [] | ||
| current_size = 0 | ||
| current.append(block) | ||
| current_size += block.size | ||
| sub, sub_map = self._build_program_set(current) | ||
| program_sets.append(sub) | ||
| index_map.append(sub_map) | ||
|
|
||
| return program_sets, index_map | ||
|
|
||
| def _executable_blocks(self, max_executables: int) -> list[_ExecutableBlock]: | ||
|
Contributor
Sorry, this definitely needs a description 😢
Contributor
Rather, maybe a detailed description of the splitting approach.
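One way to picture the splitting approach is as a list of blocks, one per (program, parameter set, observable window), produced in enumeration order. The sketch below is a simplified reading of `_executable_blocks`, not the PR's implementation. `Block` and `executable_blocks` are hypothetical names, each program is reduced to a `(num_param_sets, num_observables)` pair, and at least one observable per program is assumed.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Block:
    """Hypothetical stand-in for the PR's _ExecutableBlock."""

    prog_idx: int
    param_set_index: int
    obs_window: tuple[int, int]
    original_indices: list[int]


def executable_blocks(shapes: list[tuple[int, int]], max_executables: int) -> list[Block]:
    """Build one block per (program, parameter set, observable window)."""
    blocks: list[Block] = []
    next_index = 0
    for prog_idx, (num_ps, num_obs) in enumerate(shapes):
        # A single window unless one parameter set alone exceeds the budget.
        windows = [
            (start, min(start + max_executables, num_obs))
            for start in range(0, num_obs, max_executables)
        ]
        for ps_idx in range(num_ps):
            for start, stop in windows:
                size = stop - start
                indices = list(range(next_index, next_index + size))
                blocks.append(Block(prog_idx, ps_idx, (start, stop), indices))
                next_index += size
    return blocks


# One binding with 2 parameter sets and 5 observables, budget of 3 executables.
for block in executable_blocks([(2, 5)], max_executables=3):
    print(block)
```

Each block is an indivisible unit for the packing loop in `split`, which is why the observable list is only chunked when a single parameter set would not fit on its own.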
||
| blocks = [] | ||
| orig_idx = 0 | ||
| for prog_idx, prog in enumerate(self._programs): | ||
| if isinstance(prog, Circuit): | ||
| blocks.append( | ||
| _ExecutableBlock( | ||
| prog_idx=prog_idx, | ||
| param_set_index=None, | ||
| obs_slice=None, | ||
| size=1, | ||
| original_indices=[orig_idx], | ||
| ) | ||
| ) | ||
| orig_idx += 1 | ||
| continue | ||
|
|
||
| num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 | ||
| obs_windows = _observable_windows( | ||
| len(prog.observables) if prog.observables is not None else 1, max_executables | ||
| ) | ||
| split_observables = len(obs_windows) > 1 | ||
| for ps_idx in range(num_ps) if prog.input_sets is not None else [None]: | ||
| for start, stop in obs_windows: | ||
| size = stop - start | ||
| blocks.append( | ||
| _ExecutableBlock( | ||
| prog_idx=prog_idx, | ||
| param_set_index=ps_idx, | ||
| obs_slice=slice(start, stop) if split_observables else None, | ||
| size=size, | ||
| original_indices=list(range(orig_idx, orig_idx + size)), | ||
| ) | ||
| ) | ||
| orig_idx += size | ||
| return blocks | ||
|
|
||
| def _build_program_set(self, blocks: list[_ExecutableBlock]) -> tuple[ProgramSet, list[int]]: | ||
| entries = [] | ||
| sub_map = [] | ||
| i = 0 | ||
| while i < len(blocks): | ||
| head = blocks[i] | ||
| prog = self._programs[head.prog_idx] | ||
| if head.param_set_index is None: | ||
| entries.append(_apply_obs_slice(prog, head.obs_slice)) | ||
| sub_map.extend(head.original_indices) | ||
| i += 1 | ||
| continue | ||
|
|
||
| j = i | ||
| while ( | ||
| j + 1 < len(blocks) | ||
| and blocks[j + 1].prog_idx == head.prog_idx | ||
| and blocks[j + 1].obs_slice == blocks[j].obs_slice | ||
| and blocks[j + 1].param_set_index == blocks[j].param_set_index + 1 | ||
| ): | ||
| j += 1 | ||
| start = head.param_set_index | ||
| stop = blocks[j].param_set_index + 1 | ||
| entries.append( | ||
| CircuitBinding( | ||
| prog.circuit, | ||
| input_sets=prog.input_sets.as_list()[start:stop], | ||
| observables=_slice_observables(prog.observables, head.obs_slice), | ||
| ) | ||
| ) | ||
| for k in range(i, j + 1): | ||
| sub_map.extend(blocks[k].original_indices) | ||
| i = j + 1 | ||
| return ProgramSet(entries, self._shots_per_executable), sub_map | ||
|
|
||
| @staticmethod | ||
| def zip( | ||
| circuits: Sequence[Circuit] | CircuitBinding, | ||
|
|
@@ -206,6 +368,64 @@ def __repr__(self): | |
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class _ExecutableBlock: | ||
| """Multi-index range for an equivalence class of executables sharing the same combination of | ||
| ``(circuit, observable list/Sum Hamiltonian, single parameter assignment)``. | ||
|
|
||
| Attributes: | ||
| prog_idx: Index of the originating program in ``ProgramSet.entries``. | ||
| param_set_index: Index into the originating ``CircuitBinding``'s ``input_sets``, or ``None`` | ||
| for ``Circuit`` entries and ``CircuitBinding``s with no input sets. | ||
| obs_slice: Slice into the originating observable list or ``Sum`` summands when observables | ||
| were split to fit the budget; ``None`` means the full original observable list | ||
| (or no observables). | ||
| size: Number of executables this block represents (== ``len(original_indices)``). | ||
| original_indices: The indices of this block's executables | ||
| in the order of the original program set. | ||
| """ | ||
|
|
||
| prog_idx: int | ||
| param_set_index: int | None | ||
| obs_slice: slice | None | ||
| size: int | ||
| original_indices: list[int] | ||
|
|
||
|
|
||
| def _observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]: | ||
| if num_observables <= max_executables: | ||
| return [(0, num_observables)] | ||
| windows = [] | ||
| start = 0 | ||
| while start < num_observables: | ||
| stop = min(start + max_executables, num_observables) | ||
| windows.append((start, stop)) | ||
| start = stop | ||
| return windows | ||
|
|
||
|
|
||
| def _slice_observables( | ||
| observables: Sum | Sequence[Observable] | None, obs_slice: slice | None | ||
| ) -> Sum | Sequence[Observable] | None: | ||
| if obs_slice is None or observables is None: | ||
| return observables | ||
| if isinstance(observables, Sum): | ||
| return Sum(list(observables.summands)[obs_slice]) | ||
| return list(observables)[obs_slice] | ||
|
|
||
|
|
||
| def _apply_obs_slice( | ||
| prog: CircuitBinding | Circuit, obs_slice: slice | None | ||
| ) -> CircuitBinding | Circuit: | ||
| if obs_slice is None or isinstance(prog, Circuit) or prog.observables is None: | ||
| return prog | ||
| return CircuitBinding( | ||
| prog.circuit, | ||
| input_sets=prog.input_sets, | ||
| observables=_slice_observables(prog.observables, obs_slice), | ||
| ) | ||
|
|
||
|
|
||
| def _zip_circuit_bindings( | ||
| circuit_binding: CircuitBinding, | ||
| input_sets: Sequence[Mapping[str, float]] | None, | ||
|
|
||
Hm...should this be a property? Like enumeration?