Skip to content
222 changes: 221 additions & 1 deletion src/braket/program_sets/program_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass

from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet

Expand Down Expand Up @@ -97,6 +98,167 @@ def total_shots(self) -> int:
raise ValueError("No per-executable shots defined")
return self._shots_per_executable * self.total_executables

def enumerate_executables(self) -> Iterator[tuple[int, int, int]]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm...should this be a property? Like enumeration?

"""Yield ``(binding_index, parameter_set_index, observable_index)`` tuples in order,
one per executable.

The iteration order is: iterate over ``self.entries``; within each entry,
iterate over parameter set indices; within each parameter set index,
iterate over observable indices. The total number of yields is ``self.total_executables``.

For ``Circuit``s and ``CircuitBinding``s with no input sets, ``parameter_set_index`` is 0.
For entries with no observables, ``observable_index`` is 0. For ``CircuitBinding``s with a
``Sum`` Hamiltonian, ``observable_index`` ranges over the summands.

This ordering is used by ``split`` to build its index map and by
``ProgramSetQuantumTaskResult.merge`` to merge results back into the original shape.

Yields:
tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``.
"""
for binding_idx, prog in enumerate(self._programs):
if isinstance(prog, Circuit):
yield binding_idx, 0, 0
continue
num_obs = len(prog.observables) if prog.observables is not None else 1
for ps_idx in range(len(prog.input_sets) if prog.input_sets is not None else 1):
for obs_idx in range(num_obs):
yield binding_idx, ps_idx, obs_idx

def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]:
"""Split this program set into program sets of at most ``max_executables`` executables,
alongside a map that records the position in the original program set of each executable
in each of the generated program sets.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would add, i.e. indexed from (0, total_executables-1)


When a single parameter set index of a ``CircuitBinding`` would by itself exceed
``max_executables`` due to its observable list or ``Sum`` Hamiltonian being larger than
the budget, the observable list is split into chunks of at most ``max_executables`` entries
(``Sum`` summands are sliced with coefficients preserved). Observable splitting is only
performed when necessary; otherwise the full observable list or ``Sum`` is kept intact.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have consequences for max splitting? What is the worst case example here? I.e. CB of only observables that are at the border of being too big for 2 to fit in one? I.e. 10 executables but only a 19 ps split? This approaches ~0.5 filling efficiency for large N.

I think a flatten option might be nice, just so the user knows it has been split 1-1, though it adds complexity to the merge as well. Is there a simpler way to approach/think about this?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if it doesn't, that is okay - maybe we can just make it a bit more clear.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also seems to occur on the other side of the approach - 10 executables, but only a 9 split limit causes -> 0.5 filling efficiency.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also just occured to me we kind of have a flatten...bind_observables_to_inputs would do something like this - but there is no, "unbind", or reorder. It would be cool if merge could take that as an input as well though, as it should also have the same ordering.


Args:
max_executables (int): The maximum number of executables per program
set. Must be positive.

Returns:
tuple[list[ProgramSet], list[list[int]]]: ``(program_sets, index_map)``.
``index_map[k][j]`` is the index of the executable that the j-th executable of
``program_sets[k]`` represents.
If this program set already fits within ``max_executables``, the returned
program-set list is ``[self]`` and the index_map is ``[[0, 1, ...,
total_executables - 1]]``.

Raises:
ValueError: If ``max_executables`` is not positive.

Examples:
>>> ps = ProgramSet([
... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables
... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables
... ])
>>> subs, index_map = ps.split(120)
>>> [s.total_executables for s in subs]
[120, 120, 120, 120, 20]
>>> sum(len(m) for m in index_map) == ps.total_executables
True
"""
if max_executables <= 0:
raise ValueError(f"max_executables must be positive, got {max_executables}")

if self.total_executables <= max_executables:
return [self], [list(range(self.total_executables))]

program_sets = []
index_map = []
current = []
current_size = 0
for block in self._executable_blocks(max_executables):
if current and current_size + block.size > max_executables:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically the majority of the logic, no? If A + next can fit, use it. Else, continue.

sub, sub_map = self._build_program_set(current)
program_sets.append(sub)
index_map.append(sub_map)
current = []
current_size = 0
current.append(block)
current_size += block.size
sub, sub_map = self._build_program_set(current)
program_sets.append(sub)
index_map.append(sub_map)

return program_sets, index_map

def _executable_blocks(self, max_executables: int) -> list[_ExecutableBlock]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this definitely needs a description 😢

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather, maybe a detailed description of the splitting approach.

blocks = []
orig_idx = 0
for prog_idx, prog in enumerate(self._programs):
if isinstance(prog, Circuit):
blocks.append(
_ExecutableBlock(
prog_idx=prog_idx,
param_set_index=None,
obs_slice=None,
size=1,
original_indices=[orig_idx],
)
)
orig_idx += 1
continue

num_ps = len(prog.input_sets) if prog.input_sets is not None else 1
obs_windows = _observable_windows(
len(prog.observables) if prog.observables is not None else 1, max_executables
)
split_observables = len(obs_windows) > 1
for ps_idx in range(num_ps) if prog.input_sets is not None else [None]:
for start, stop in obs_windows:
size = stop - start
blocks.append(
_ExecutableBlock(
prog_idx=prog_idx,
param_set_index=ps_idx,
obs_slice=slice(start, stop) if split_observables else None,
size=size,
original_indices=list(range(orig_idx, orig_idx + size)),
)
)
orig_idx += size
return blocks

def _build_program_set(self, blocks: list[_ExecutableBlock]) -> tuple[ProgramSet, list[int]]:
entries = []
sub_map = []
i = 0
while i < len(blocks):
head = blocks[i]
prog = self._programs[head.prog_idx]
if head.param_set_index is None:
entries.append(_apply_obs_slice(prog, head.obs_slice))
sub_map.extend(head.original_indices)
i += 1
continue

j = i
while (
j + 1 < len(blocks)
and blocks[j + 1].prog_idx == head.prog_idx
and blocks[j + 1].obs_slice == blocks[j].obs_slice
and blocks[j + 1].param_set_index == blocks[j].param_set_index + 1
):
j += 1
start = head.param_set_index
stop = blocks[j].param_set_index + 1
entries.append(
CircuitBinding(
prog.circuit,
input_sets=prog.input_sets.as_list()[start:stop],
observables=_slice_observables(prog.observables, head.obs_slice),
)
)
for k in range(i, j + 1):
sub_map.extend(blocks[k].original_indices)
i = j + 1
return ProgramSet(entries, self._shots_per_executable), sub_map

@staticmethod
def zip(
circuits: Sequence[Circuit] | CircuitBinding,
Expand Down Expand Up @@ -206,6 +368,64 @@ def __repr__(self):
)


@dataclass
class _ExecutableBlock:
"""Multi-index range for an equivalence class of executables sharing the same combination of
``(circuit, observable list/Sum Hamiltonian, single parameter assignment)``.

Attributes:
prog_idx: Index of the originating program in ``ProgramSet.entries``.
param_set_index: Index into the originating ``CircuitBinding``'s ``input_sets``, or ``None``
for ``Circuit`` entries and ``CircuitBinding``s with no input sets.
obs_slice: Slice into the originating observable list or ``Sum`` summands when observables
were split to fit the budget; ``None`` means the full original observable list
(or no observables).
size: Number of executables this block represents (== ``len(original_indices)``).
original_indices: The indices of this block's executables
in the order of the original program set.
"""

prog_idx: int
param_set_index: int | None
obs_slice: slice | None
size: int
original_indices: list[int]


def _observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]:
if num_observables <= max_executables:
return [(0, num_observables)]
windows = []
start = 0
while start < num_observables:
stop = min(start + max_executables, num_observables)
windows.append((start, stop))
start = stop
return windows


def _slice_observables(
observables: Sum | Sequence[Observable] | None, obs_slice: slice | None
) -> Sum | Sequence[Observable] | None:
if obs_slice is None or observables is None:
return observables
if isinstance(observables, Sum):
return Sum(list(observables.summands)[obs_slice])
return list(observables)[obs_slice]


def _apply_obs_slice(
prog: CircuitBinding | Circuit, obs_slice: slice | None
) -> CircuitBinding | Circuit:
if obs_slice is None or isinstance(prog, Circuit) or prog.observables is None:
return prog
return CircuitBinding(
prog.circuit,
input_sets=prog.input_sets,
observables=_slice_observables(prog.observables, obs_slice),
)


def _zip_circuit_bindings(
circuit_binding: CircuitBinding,
input_sets: Sequence[Mapping[str, float]] | None,
Expand Down
Loading
Loading