From 4e29c65a7b40fed666518ecfd65b7298524f882c Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Wed, 6 May 2026 13:27:01 -0700 Subject: [PATCH 1/9] feat: Method to split program set `ProgramSet.split` takes in a `max_excutables` parameter, so that the generated program sets will contain at most that many executables. Also fixed coverage output when running `tox`. --- src/braket/program_sets/program_set.py | 108 +++++++++++ .../braket/program_sets/test_program_set.py | 167 ++++++++++++++++++ tox.ini | 2 +- 3 files changed, 276 insertions(+), 1 deletion(-) diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index 0f4964af5..4663d09a3 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -97,6 +97,114 @@ def total_shots(self) -> int: raise ValueError("No per-executable shots defined") return self._shots_per_executable * self.total_executables + def split(self, max_executables: int) -> list[ProgramSet]: + """ + Split this program set into a list of program sets with + at most ``max_executables`` executables. + + Sum Hamiltonians and lists of observables will not be broken into separate program sets; + consequently, this method will fail if the size of any Hamiltonian or observable list + exceeds ``max_executables``. + + Adjacent triples originating from the same ``CircuitBinding`` are coalesced into + a single multi-parameter-set ``CircuitBinding`` in the resulting sub-program set. + + Concatenating the executables of the program sets in the list in order reproduces + the executables of the original program set. + + Args: + max_executables (int): The maximum number of executables allowed per + sub-program set. Must be positive. + + Returns: + list[ProgramSet]: The sub-program sets. If this program set already fits + within ``max_executables``, a single-element list containing ``self`` is + returned. + + Raises: + ValueError: If ``max_executables`` is not positive, or if a single triple + (one parameter-set index of a single ``CircuitBinding``) requires + more than ``max_executables`` executables, because its observable list or + ``Sum`` Hamiltonian is larger than allowed. + + Examples: + >>> ps = ProgramSet([ + ... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables + ... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables + ... ]) + >>> sub = ps.split(120) + >>> [s.total_executables for s in sub] + [120, 120, 120, 80, 60] + """ + if max_executables <= 0: + raise ValueError(f"max_executables must be positive, got {max_executables}") + + if self.total_executables <= max_executables: + return [self] + + program_sets = [] + current = [] + current_size = 0 + for triple in self._enumerate_triples(max_executables): + size = triple[2] + if current and current_size + size > max_executables: + program_sets.append(self._build_sub_program_set(current)) + current = [] + current_size = 0 + current.append(triple) + current_size += size + program_sets.append(self._build_sub_program_set(current)) + + return program_sets + + def _enumerate_triples(self, max_executables: int) -> list[tuple[int, int | None, int]]: + triples = [] + for prog_idx, prog in enumerate(self._programs): + if isinstance(prog, Circuit): + triples.append((prog_idx, None, 1)) + continue + obs = prog.observables + class_size = max(1, len(obs)) if obs is not None else 1 + if class_size > max_executables: + raise ValueError( + f"Program at index {prog_idx} has a single parameter-set index with " + f"{class_size} executables, exceeding max_executables={max_executables}" + ) + input_sets = prog.input_sets + if input_sets is None: + triples.append((prog_idx, None, class_size)) + else: + triples.extend((prog_idx, i, class_size) for i in range(len(input_sets))) + return triples + + def _build_sub_program_set(self, triples: list[tuple[int, int | None, int]]) -> ProgramSet: + entries = [] + i = 0 + while i < len(triples): + prog_idx, param_idx, _ = triples[i] + prog = self._programs[prog_idx] + if param_idx is None: + entries.append(prog) + i += 1 + continue + j = i + while ( + j + 1 < len(triples) + and triples[j + 1][0] == prog_idx + and triples[j + 1][1] == triples[j][1] + 1 + ): + j += 1 + start, stop = triples[i][1], triples[j][1] + 1 + entries.append( + CircuitBinding( + prog.circuit, + input_sets=prog.input_sets.as_list()[start:stop], + observables=prog.observables, + ) + ) + i = j + 1 + return ProgramSet(entries, self._shots_per_executable) + @staticmethod def zip( circuits: Sequence[Circuit] | CircuitBinding, diff --git a/test/unit_tests/braket/program_sets/test_program_set.py b/test/unit_tests/braket/program_sets/test_program_set.py index 2706007d4..f0cc04f90 100644 --- a/test/unit_tests/braket/program_sets/test_program_set.py +++ b/test/unit_tests/braket/program_sets/test_program_set.py @@ -534,3 +534,170 @@ def test_inequality(circuit_rx_parametrized): program_set = ProgramSet([binding, binding]) assert program_set != ProgramSet([binding, circuit_rx_parametrized]) assert program_set != circuit_rx_parametrized + + +def test_split_already_fits(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) + program_set = ProgramSet(binding) + sub = program_set.split(10) + assert sub == [program_set] + assert sub[0] is program_set + + +def test_split_exact_fit(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) + program_set = ProgramSet(binding) + sub = program_set.split(2) + assert sub == [program_set] + assert sub[0] is program_set + + +def test_split_plain_circuits(): + circs = [ghz(1), ghz(2), ghz(3), ghz(1), ghz(2)] + program_set = ProgramSet(circs, shots_per_executable=10) + sub = program_set.split(2) + assert [s.total_executables for s in sub] == [2, 2, 1] + assert sub[0].entries == circs[0:2] + assert sub[1].entries == circs[2:4] + assert sub[2].entries == circs[4:5] + + +def test_split_single_binding_packed(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(10)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding) + sub = program_set.split(3) + assert [s.total_executables for s in sub] == [3, 3, 3, 1] + # Each sub-program-set is a single coalesced binding over a contiguous slice. + for s in sub: + assert len(s) == 1 + assert s.entries[0].circuit == circuit_rx_parametrized + assert s.entries[0].observables is None + thetas = [] + for s in sub: + thetas.extend(s.entries[0].input_sets.as_dict()["theta"]) + assert thetas == inputs["theta"] + + +def test_split_with_observables(circuit_rx_parametrized): + # 5 parameter-set indices, 4 observables => 5 classes of size 4. + inputs = {"theta": [float(i) for i in range(5)]} + observables = [X(0), Y(0), Z(0), X(0) @ Y(1)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + sub = program_set.split(8) + assert [s.total_executables for s in sub] == [8, 8, 4] + # Observables propagate unchanged (never split across sub-program-sets). + for s in sub: + assert s.entries[0].observables == observables + + +def test_split_with_sum_hamiltonian(circuit_rx_parametrized): + # Sum with 3 summands => class size = 3 per parameter-set index. + inputs = {"theta": [float(i) for i in range(4)]} + hamiltonian = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0) + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=hamiltonian) + program_set = ProgramSet(binding) + sub = program_set.split(6) + assert [s.total_executables for s in sub] == [6, 6] + # Sum preserved intact. + for s in sub: + assert s.entries[0].observables is hamiltonian + + +def test_split_worked_example(circuit_rx_parametrized): + # Two bindings: c1 with 100 param sets × 4 obs, c2 with 50 param sets × 2 obs. + c1 = circuit_rx_parametrized + c2 = Circuit().rx(0, FreeParameter("phi")) + obs1 = [X(0), Y(0), Z(0), X(0) @ Y(1)] + obs2 = [X(0), Z(0)] + binding1 = CircuitBinding(c1, {"theta": [float(i) for i in range(100)]}, obs1) + binding2 = CircuitBinding(c2, {"phi": [float(i) for i in range(50)]}, obs2) + program_set = ProgramSet([binding1, binding2]) + + sub = program_set.split(120) + # Greedy packing fills each bucket up to the budget before flushing. + assert [s.total_executables for s in sub] == [120, 120, 120, 120, 20] + assert sum(s.total_executables for s in sub) == program_set.total_executables + # First three buckets are pure c1 (30 × 4 each). + for i in range(3): + assert len(sub[i]) == 1 + assert sub[i].entries[0].circuit == c1 + assert len(sub[i].entries[0].input_sets) == 30 + # Bucket 3 straddles both bindings (10 × 4 + 40 × 2 = 120); coalesced per binding. + assert len(sub[3]) == 2 + assert sub[3].entries[0].circuit == c1 + assert len(sub[3].entries[0].input_sets) == 10 + assert sub[3].entries[1].circuit == c2 + assert len(sub[3].entries[1].input_sets) == 40 + # Last bucket is pure c2 remainder (10 × 2 = 20). + assert len(sub[4]) == 1 + assert sub[4].entries[0].circuit == c2 + assert len(sub[4].entries[0].input_sets) == 10 + + +def test_split_preserves_shots(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(5)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding, shots_per_executable=100) + sub = program_set.split(2) + assert all(s.shots_per_executable == 100 for s in sub) + assert sum(s.total_shots for s in sub) == program_set.total_shots + + +def test_split_coalesces_adjacent_same_binding(circuit_rx_parametrized): + # 6 parameter-set indices, class size 1, max_executables=4 => buckets of 4, 2. + # Each bucket should contain one coalesced multi-parameter-set binding, + # not four (resp. two) separate single-parameter-set bindings. + inputs = {"theta": [float(i) for i in range(6)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding) + sub = program_set.split(4) + assert [len(s) for s in sub] == [1, 1] + assert len(sub[0].entries[0].input_sets) == 4 + assert len(sub[1].entries[0].input_sets) == 2 + + +def test_split_binding_without_input_sets(circuit_rx_parametrized): + # A binding with only observables is a single class of size len(observables). + c1 = circuit_rx_parametrized + c2 = Circuit().rx(0, FreeParameter("phi")) + binding_a = CircuitBinding(c1, observables=[X(0), Y(0)]) # size 2 + binding_b = CircuitBinding(c2, observables=[X(0), Y(0), Z(0)]) # size 3 + program_set = ProgramSet([binding_a, binding_b]) + sub = program_set.split(3) + assert [s.total_executables for s in sub] == [2, 3] + assert sub[0].entries == [binding_a] + assert sub[1].entries == [binding_b] + + +def test_split_non_positive_raises(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}]) + program_set = ProgramSet(binding) + with pytest.raises(ValueError, match="must be positive"): + program_set.split(0) + with pytest.raises(ValueError, match="must be positive"): + program_set.split(-3) + + +def test_split_oversize_class_raises(circuit_rx_parametrized): + # One parameter-set index with 3 observables exceeds max_executables=2. + inputs = {"theta": [1.0, 2.0]} + observables = [X(0), Y(0), Z(0)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + with pytest.raises(ValueError, match="exceeding max_executables"): + program_set.split(2) + + +def test_split_sub_program_sets_are_serializable(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(10)]} + observables = [X(0), Y(0)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + sub = program_set.split(6) + # Each sub-program set is a fully formed ProgramSet: to_ir() works and returns a + # single-program IR (one coalesced CircuitBinding per sub-program set here). + for s in sub: + ir = s.to_ir() + assert len(ir.programs) == len(s) diff --git a/tox.ini b/tox.ini index 6bb67ecd2..9e13e17ad 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,7 @@ basepython = python3 deps = {[test-deps]deps} commands = - pytest {posargs} --cov=braket --cov-report term-missing --cov-report html --cov-report xml --cov-append + pytest {posargs} --cov --cov-report term-missing --cov-report html --cov-report xml --cov-append extras = test [testenv:integ-tests] From cc80430cd967c82cadc6464aefbb870cdbea58c7 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 8 May 2026 17:11:47 -0700 Subject: [PATCH 2/9] Add results combiner --- src/braket/program_sets/program_set.py | 238 +++++++++--- .../tasks/program_set_quantum_task_result.py | 209 +++++++++- .../braket/program_sets/test_program_set.py | 201 +++++++--- .../test_program_set_quantum_task_result.py | 366 ++++++++++++++++++ 4 files changed, 895 insertions(+), 119 deletions(-) diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index 4663d09a3..f3935b66e 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -13,7 +13,8 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence +from dataclasses import dataclass from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet @@ -97,113 +98,166 @@ def total_shots(self) -> int: raise ValueError("No per-executable shots defined") return self._shots_per_executable * self.total_executables - def split(self, max_executables: int) -> list[ProgramSet]: - """ - Split this program set into a list of program sets with - at most ``max_executables`` executables. + def enumerate_executables(self) -> Iterator[tuple[int, int, int]]: + """Yield ``(binding_index, parameter_set_index, observable_index)`` tuples in order, + one per executable. + + The iteration order is: iterate over ``self.entries``; within each entry, + iterate over parameter set indices; within each parameter set index, + iterate over observable indices. The total number of yields is ``self.total_executables``. - Sum Hamiltonians and lists of observables will not be broken into separate program sets; - consequently, this method will fail if the size of any Hamiltonian or observable list - exceeds ``max_executables``. + For ``Circuit``s and ``CircuitBinding``s with no input sets, ``parameter_set_index`` is 0. + For entries with no observables, ``observable_index`` is 0. For ``CircuitBinding``s with a + ``Sum`` Hamiltonian, ``observable_index`` ranges over the summands. - Adjacent triples originating from the same ``CircuitBinding`` are coalesced into - a single multi-parameter-set ``CircuitBinding`` in the resulting sub-program set. + This ordering is used by ``split`` to build its index map and by + ``ProgramSetQuantumTaskResult.from_multiple`` to merge results back into the original shape. - Concatenating the executables of the program sets in the list in order reproduces - the executables of the original program set. + Yields: + tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``. + """ + for b_idx, prog in enumerate(self._programs): + if isinstance(prog, Circuit): + yield b_idx, 0, 0 + continue + num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 + num_obs = len(prog.observables) if prog.observables is not None else 1 + for ps_idx in range(num_ps): + for obs_idx in range(num_obs): + yield b_idx, ps_idx, obs_idx + + def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]: + """Split this program set into program sets of at most ``max_executables`` executables, + alongside a map that records the position in the original program set of each executable + in each of the generated program sets. + + When a single parameter set index of a ``CircuitBinding`` would by itself exceed + ``max_executables`` due to its observable list or ``Sum`` Hamiltonian being larger than + the budget, the observable list is split into chunks of at most ``max_executables`` entries + (``Sum`` summands are sliced with coefficients preserved). Observable splitting is only + performed when necessary; otherwise the full observable list or ``Sum`` is kept intact. Args: - max_executables (int): The maximum number of executables allowed per - sub-program set. Must be positive. + max_executables (int): The maximum number of executables per program + set. Must be positive. Returns: - list[ProgramSet]: The sub-program sets. If this program set already fits - within ``max_executables``, a single-element list containing ``self`` is - returned. + tuple[list[ProgramSet], list[list[int]]]: ``(program_sets, index_map)``. + ``index_map[k][j]`` is the index of the executable that the j-th executable of + ``program_sets[k]`` represents. + If this program set already fits within ``max_executables``, the returned + program-set list is ``[self]`` and the index_map is ``[[0, 1, ..., + total_executables - 1]]``. Raises: - ValueError: If ``max_executables`` is not positive, or if a single triple - (one parameter-set index of a single ``CircuitBinding``) requires - more than ``max_executables`` executables, because its observable list or - ``Sum`` Hamiltonian is larger than allowed. + ValueError: If ``max_executables`` is not positive. Examples: >>> ps = ProgramSet([ ... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables ... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables ... ]) - >>> sub = ps.split(120) - >>> [s.total_executables for s in sub] - [120, 120, 120, 80, 60] + >>> subs, index_map = ps.split(120) + >>> [s.total_executables for s in subs] + [120, 120, 120, 120, 20] + >>> sum(len(m) for m in index_map) == ps.total_executables + True """ if max_executables <= 0: raise ValueError(f"max_executables must be positive, got {max_executables}") if self.total_executables <= max_executables: - return [self] + return [self], [list(range(self.total_executables))] + triples = self._enumerate_triples(max_executables) program_sets = [] + index_map = [] current = [] current_size = 0 - for triple in self._enumerate_triples(max_executables): - size = triple[2] - if current and current_size + size > max_executables: - program_sets.append(self._build_sub_program_set(current)) + for cls in triples: + if current and current_size + cls.size > max_executables: + sub, sub_map = self._build_program_set(current) + program_sets.append(sub) + index_map.append(sub_map) current = [] current_size = 0 - current.append(triple) - current_size += size - program_sets.append(self._build_sub_program_set(current)) + current.append(cls) + current_size += cls.size + sub, sub_map = self._build_program_set(current) + program_sets.append(sub) + index_map.append(sub_map) - return program_sets + return program_sets, index_map - def _enumerate_triples(self, max_executables: int) -> list[tuple[int, int | None, int]]: + def _enumerate_triples(self, max_executables: int) -> list[_Triple]: triples = [] + orig_idx = 0 for prog_idx, prog in enumerate(self._programs): if isinstance(prog, Circuit): - triples.append((prog_idx, None, 1)) - continue - obs = prog.observables - class_size = max(1, len(obs)) if obs is not None else 1 - if class_size > max_executables: - raise ValueError( - f"Program at index {prog_idx} has a single parameter-set index with " - f"{class_size} executables, exceeding max_executables={max_executables}" + triples.append( + _Triple( + prog_idx=prog_idx, + param_set_index=None, + obs_slice=None, + size=1, + original_indices=[orig_idx], + ) ) - input_sets = prog.input_sets - if input_sets is None: - triples.append((prog_idx, None, class_size)) - else: - triples.extend((prog_idx, i, class_size) for i in range(len(input_sets))) + orig_idx += 1 + continue + + num_obs = len(prog.observables) if prog.observables is not None else 1 + num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 + ps_indices = list(range(num_ps)) if prog.input_sets is not None else [None] + obs_windows = _observable_windows(num_obs, max_executables) + split_observables = len(obs_windows) > 1 + for ps_idx in ps_indices: + for start, stop in obs_windows: + size = stop - start + triples.append( + _Triple( + prog_idx=prog_idx, + param_set_index=ps_idx, + obs_slice=slice(start, stop) if split_observables else None, + size=size, + original_indices=list(range(orig_idx, orig_idx + size)), + ) + ) + orig_idx += size return triples - def _build_sub_program_set(self, triples: list[tuple[int, int | None, int]]) -> ProgramSet: - entries = [] + def _build_program_set(self, triples: list[_Triple]) -> tuple[ProgramSet, list[int]]: + entries: list[CircuitBinding | Circuit] = [] + sub_sub_map: list[int] = [] i = 0 while i < len(triples): - prog_idx, param_idx, _ = triples[i] - prog = self._programs[prog_idx] - if param_idx is None: - entries.append(prog) + head = triples[i] + prog = self._programs[head.prog_idx] + if head.param_set_index is None: + entries.append(_apply_obs_slice(prog, head.obs_slice)) + sub_sub_map.extend(head.original_indices) i += 1 continue + j = i while ( j + 1 < len(triples) - and triples[j + 1][0] == prog_idx - and triples[j + 1][1] == triples[j][1] + 1 + and triples[j + 1].prog_idx == head.prog_idx + and triples[j + 1].obs_slice == triples[j].obs_slice + and triples[j + 1].param_set_index == triples[j].param_set_index + 1 ): j += 1 - start, stop = triples[i][1], triples[j][1] + 1 - entries.append( - CircuitBinding( - prog.circuit, - input_sets=prog.input_sets.as_list()[start:stop], - observables=prog.observables, - ) + start, stop = head.param_set_index, triples[j].param_set_index + 1 + coalesced = CircuitBinding( + prog.circuit, + input_sets=prog.input_sets.as_list()[start:stop], + observables=_slice_observables(prog.observables, head.obs_slice), ) + entries.append(coalesced) + for k in range(i, j + 1): + sub_sub_map.extend(triples[k].original_indices) i = j + 1 - return ProgramSet(entries, self._shots_per_executable) + return ProgramSet(entries, self._shots_per_executable), sub_sub_map @staticmethod def zip( @@ -314,6 +368,64 @@ def __repr__(self): ) +@dataclass(frozen=True) +class _Triple: + """A contiguous run of executables sharing the same ``(circuit, observable list, + single parameter assignment)`` triple. + + Attributes: + prog_idx: Index of the originating program in ``ProgramSet.entries``. + param_set_index: Index into the originating ``CircuitBinding``'s ``input_sets``, or ``None`` + for ``Circuit`` entries and ``CircuitBinding``s with no input sets. + obs_slice: Slice into the originating observable list or ``Sum`` summands when observables + were split to fit the budget; ``None`` means the full original observable list + (or no observables). + size: Number of executables this triple represents (== ``len(original_indices)``). + original_indices: The indices of this triple's executables + in the order of the original program set. + """ + + prog_idx: int + param_set_index: int | None + obs_slice: slice | None + size: int + original_indices: list[int] + + +def _observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]: + if num_observables <= max_executables: + return [(0, num_observables)] + windows = [] + start = 0 + while start < num_observables: + stop = min(start + max_executables, num_observables) + windows.append((start, stop)) + start = stop + return windows + + +def _slice_observables( + observables: Sum | Sequence[Observable] | None, obs_slice: slice | None +) -> Sum | Sequence[Observable] | None: + if obs_slice is None or observables is None: + return observables + if isinstance(observables, Sum): + return Sum(list(observables.summands)[obs_slice]) + return list(observables)[obs_slice] + + +def _apply_obs_slice( + prog: CircuitBinding | Circuit, obs_slice: slice | None +) -> CircuitBinding | Circuit: + if obs_slice is None or isinstance(prog, Circuit) or prog.observables is None: + return prog + return CircuitBinding( + prog.circuit, + input_sets=prog.input_sets, + observables=_slice_observables(prog.observables, obs_slice), + ) + + def _zip_circuit_bindings( circuit_binding: CircuitBinding, input_sets: Sequence[Mapping[str, float]] | None, diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index c4c37b00e..02c1e6091 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -16,7 +16,7 @@ import warnings from collections import Counter from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace import boto3 import numpy as np @@ -31,10 +31,15 @@ ProgramSetTaskMetadata, ProgramSetTaskResult, ) +from braket.task_result.program_set_executable_result_v1 import ( + ProgramSetExecutableResultMetadata, +) +from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata -from braket.circuits import Observable +from braket.circuits import Circuit, Observable from braket.circuits.observable import EULER_OBSERVABLE_PREFIX from braket.circuits.observables import Sum +from braket.circuits.serialization import IRType from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.tasks.measurement_utils import ( expectation_from_measurements, @@ -370,6 +375,99 @@ def from_object( program_set=program_set, ) + @staticmethod + def from_multiple( + results: Sequence[ProgramSetQuantumTaskResult], + program_set: ProgramSet, + index_map: list[list[int]], + ) -> ProgramSetQuantumTaskResult: + """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running + each program set of ``program_set.split(...)``. + + ``index_map`` is the per-executable map returned alongside the program sets by + ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, + of the executable that the jth executable of the kth task represents. The kth task's + executables are read in order for its program set, namely across ``results[k].entries``, + and within each ``CompositeEntry`` across its ``entries``. + + The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had + been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, + and ``MeasuredEntry`` objects in the order of the program. + + Expectation values and ``Sum`` Hamiltonian expectations are computed + for the original ``ProgramSet``. + + Args: + results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same + order as ``program_set.split``'s return. + program_set (ProgramSet): The original unsplit program set. + index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. + + Returns: + ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. + + Raises: + ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` + doesn't match ``program_set.total_executables``, or if any task produces a + different number of executables than its map expects. + """ + if len(results) != len(index_map): + raise ValueError( + f"Got {len(results)} task results but {len(index_map)} entries in index_map" + ) + total_executables = program_set.total_executables + total_mapped = sum(len(m) for m in index_map) + if total_mapped != total_executables: + raise ValueError( + f"Index map covers {total_mapped} executables but the original program set " + f"has {total_executables}" + ) + + binding_programs = [_binding_to_program(binding) for binding in program_set.entries] + triples = list(program_set.enumerate_executables()) + binding_executable_counts = [_count_executables(b) for b in program_set.entries] + + metas = [r.task_metadata for r in results] + first_num_execs = _num_executables_from_metadata(metas[0]) + shots_per_executable = metas[0].requestedShots // first_num_execs if first_num_execs else 0 + + buffer = [None] * total_executables + for k, result in enumerate(results): + _buffer_result( + k=k, + result=result, + map_k=index_map[k], + program_set=program_set, + binding_programs=binding_programs, + triples=triples, + buffer=buffer, + ) + + entries = [] + start = 0 + for binding_idx, binding in enumerate(program_set.entries): + count = binding_executable_counts[binding_idx] + program = binding_programs[binding_idx] + observables = None if isinstance(binding, Circuit) else binding.observables + entries.append( + CompositeEntry( + entries=buffer[start : start + count], + program=program, + inputs=CompositeEntry._get_inputs(program, observables), + observables=observables, + shots_per_executable=shots_per_executable, + additional_metadata=None, + ) + ) + start += count + + return ProgramSetQuantumTaskResult( + entries=entries, + task_metadata=_aggregate_task_metadata(metas, program_set), + num_executables=total_executables, + program_set=program_set, + ) + def __len__(self): return len(self.entries) @@ -481,6 +579,113 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: return counter +def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: + if isinstance(binding, Circuit): + return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) + return binding.to_ir() + + +def _count_executables(binding: CircuitBinding | Circuit) -> int: + if isinstance(binding, Circuit): + return 1 + num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 + num_obs = len(binding.observables) if binding.observables is not None else 1 + return num_ps * num_obs + + +def _num_executables_from_metadata(metadata: ProgramSetTaskMetadata) -> int: + return sum(len(p.executables) for p in metadata.programMetadata) + + +def _buffer_result( + *, + k: int, + result: ProgramSetQuantumTaskResult, + map_k: list[int], + program_set: ProgramSet, + binding_programs: list[Program], + triples: list[tuple[int, int, int]], + buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], +) -> None: + j = 0 + for composite in result.entries: + for entry in composite.entries: + if j >= len(map_k): + raise ValueError( + f"t=Task {result.task_metadata.id} at index {k} " + "produced more executables than index map expects" + ) + orig_idx = map_k[j] + b_idx, ps_idx, obs_idx = triples[orig_idx] + buffer[orig_idx] = _convert_measured_entry( + entry, + program_set.entries[b_idx], + binding_programs[b_idx], + ps_idx, + obs_idx, + ) + j += 1 + if j != len(map_k): + raise ValueError( + f"Task {result.task_metadata.id} at index {k} produced {j} executables " + f"but index map expected {len(map_k)}" + ) + + +def _convert_measured_entry( + entry: MeasuredEntry | ProgramSetExecutableFailure, + original_binding: CircuitBinding | Circuit, + original_program: Program, + parameter_set_index: int, + observable_index: int, +) -> MeasuredEntry | ProgramSetExecutableFailure: + if isinstance(entry, ProgramSetExecutableFailure): + return entry + if isinstance(original_binding, Circuit): + return replace(entry, program=original_program.source, inputs=None, observable=None) + observables = original_binding.observables + if observables is None: + observable: Observable | None = None + num_obs = 1 + elif isinstance(observables, Sum): + observable = observables.summands[observable_index] + num_obs = len(observables.summands) + else: + observable = observables[observable_index] + num_obs = len(observables) + orig_inputs_index = parameter_set_index * num_obs + observable_index + program_inputs = original_program.inputs or {} + inputs = {key: value[orig_inputs_index] for key, value in program_inputs.items()} or None + return replace(entry, program=original_program.source, inputs=inputs, observable=observable) + + +def _aggregate_task_metadata( + metas: Sequence[ProgramSetTaskMetadata], program_set: ProgramSet +) -> ProgramSetTaskMetadata: + first = metas[0] + created_values = [m.createdAt for m in metas if m.createdAt] + ended_values = [m.endedAt for m in metas if m.endedAt] + return ProgramSetTaskMetadata( + id=";".join(meta.id for meta in metas), + deviceId=first.deviceId, + requestedShots=sum(m.requestedShots for m in metas), + successfulShots=sum(m.successfulShots for m in metas), + programMetadata=[ + ProgramMetadata( + executables=[ + ProgramSetExecutableResultMetadata() for _ in range(_count_executables(b)) + ] + ) + for b in program_set.entries + ], + deviceParameters=None, + createdAt=min(created_values) if created_values else None, + endedAt=max(ended_values) if ended_values else None, + status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", + totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), + ) + + def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: """Retrieve the S3 object body. diff --git a/test/unit_tests/braket/program_sets/test_program_set.py b/test/unit_tests/braket/program_sets/test_program_set.py index f0cc04f90..4a5b1792c 100644 --- a/test/unit_tests/braket/program_sets/test_program_set.py +++ b/test/unit_tests/braket/program_sets/test_program_set.py @@ -539,44 +539,48 @@ def test_inequality(circuit_rx_parametrized): def test_split_already_fits(circuit_rx_parametrized): binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) program_set = ProgramSet(binding) - sub = program_set.split(10) - assert sub == [program_set] - assert sub[0] is program_set + subs, mapping = program_set.split(10) + assert subs == [program_set] + assert subs[0] is program_set + assert mapping == [[0, 1]] def test_split_exact_fit(circuit_rx_parametrized): binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) program_set = ProgramSet(binding) - sub = program_set.split(2) - assert sub == [program_set] - assert sub[0] is program_set + subs, mapping = program_set.split(2) + assert subs == [program_set] + assert subs[0] is program_set + assert mapping == [[0, 1]] def test_split_plain_circuits(): circs = [ghz(1), ghz(2), ghz(3), ghz(1), ghz(2)] program_set = ProgramSet(circs, shots_per_executable=10) - sub = program_set.split(2) - assert [s.total_executables for s in sub] == [2, 2, 1] - assert sub[0].entries == circs[0:2] - assert sub[1].entries == circs[2:4] - assert sub[2].entries == circs[4:5] + subs, mapping = program_set.split(2) + assert [s.total_executables for s in subs] == [2, 2, 1] + assert subs[0].entries == circs[0:2] + assert subs[1].entries == circs[2:4] + assert subs[2].entries == circs[4:5] + assert mapping == [[0, 1], [2, 3], [4]] def test_split_single_binding_packed(circuit_rx_parametrized): inputs = {"theta": [float(i) for i in range(10)]} binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) program_set = ProgramSet(binding) - sub = program_set.split(3) - assert [s.total_executables for s in sub] == [3, 3, 3, 1] + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 3, 1] # Each sub-program-set is a single coalesced binding over a contiguous slice. - for s in sub: + for s in subs: assert len(s) == 1 assert s.entries[0].circuit == circuit_rx_parametrized assert s.entries[0].observables is None thetas = [] - for s in sub: + for s in subs: thetas.extend(s.entries[0].input_sets.as_dict()["theta"]) assert thetas == inputs["theta"] + assert mapping == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] def test_split_with_observables(circuit_rx_parametrized): @@ -585,11 +589,12 @@ def test_split_with_observables(circuit_rx_parametrized): observables = [X(0), Y(0), Z(0), X(0) @ Y(1)] binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) program_set = ProgramSet(binding) - sub = program_set.split(8) - assert [s.total_executables for s in sub] == [8, 8, 4] + subs, mapping = program_set.split(8) + assert [s.total_executables for s in subs] == [8, 8, 4] # Observables propagate unchanged (never split across sub-program-sets). - for s in sub: + for s in subs: assert s.entries[0].observables == observables + assert sum(mapping, []) == list(range(20)) def test_split_with_sum_hamiltonian(circuit_rx_parametrized): @@ -598,11 +603,12 @@ def test_split_with_sum_hamiltonian(circuit_rx_parametrized): hamiltonian = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0) binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=hamiltonian) program_set = ProgramSet(binding) - sub = program_set.split(6) - assert [s.total_executables for s in sub] == [6, 6] - # Sum preserved intact. - for s in sub: + subs, mapping = program_set.split(6) + assert [s.total_executables for s in subs] == [6, 6] + # Sum preserved intact (no observable-splitting needed at max=6). + for s in subs: assert s.entries[0].observables is hamiltonian + assert sum(mapping, []) == list(range(12)) def test_split_worked_example(circuit_rx_parametrized): @@ -615,34 +621,36 @@ def test_split_worked_example(circuit_rx_parametrized): binding2 = CircuitBinding(c2, {"phi": [float(i) for i in range(50)]}, obs2) program_set = ProgramSet([binding1, binding2]) - sub = program_set.split(120) + subs, mapping = program_set.split(120) # Greedy packing fills each bucket up to the budget before flushing. - assert [s.total_executables for s in sub] == [120, 120, 120, 120, 20] - assert sum(s.total_executables for s in sub) == program_set.total_executables + assert [s.total_executables for s in subs] == [120, 120, 120, 120, 20] + assert sum(s.total_executables for s in subs) == program_set.total_executables # First three buckets are pure c1 (30 × 4 each). for i in range(3): - assert len(sub[i]) == 1 - assert sub[i].entries[0].circuit == c1 - assert len(sub[i].entries[0].input_sets) == 30 + assert len(subs[i]) == 1 + assert subs[i].entries[0].circuit == c1 + assert len(subs[i].entries[0].input_sets) == 30 # Bucket 3 straddles both bindings (10 × 4 + 40 × 2 = 120); coalesced per binding. - assert len(sub[3]) == 2 - assert sub[3].entries[0].circuit == c1 - assert len(sub[3].entries[0].input_sets) == 10 - assert sub[3].entries[1].circuit == c2 - assert len(sub[3].entries[1].input_sets) == 40 + assert len(subs[3]) == 2 + assert subs[3].entries[0].circuit == c1 + assert len(subs[3].entries[0].input_sets) == 10 + assert subs[3].entries[1].circuit == c2 + assert len(subs[3].entries[1].input_sets) == 40 # Last bucket is pure c2 remainder (10 × 2 = 20). - assert len(sub[4]) == 1 - assert sub[4].entries[0].circuit == c2 - assert len(sub[4].entries[0].input_sets) == 10 + assert len(subs[4]) == 1 + assert subs[4].entries[0].circuit == c2 + assert len(subs[4].entries[0].input_sets) == 10 + # Mapping covers every original executable exactly once, in order. + assert sum(mapping, []) == list(range(500)) def test_split_preserves_shots(circuit_rx_parametrized): inputs = {"theta": [float(i) for i in range(5)]} binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) program_set = ProgramSet(binding, shots_per_executable=100) - sub = program_set.split(2) - assert all(s.shots_per_executable == 100 for s in sub) - assert sum(s.total_shots for s in sub) == program_set.total_shots + subs, _ = program_set.split(2) + assert all(s.shots_per_executable == 100 for s in subs) + assert sum(s.total_shots for s in subs) == program_set.total_shots def test_split_coalesces_adjacent_same_binding(circuit_rx_parametrized): @@ -652,10 +660,10 @@ def test_split_coalesces_adjacent_same_binding(circuit_rx_parametrized): inputs = {"theta": [float(i) for i in range(6)]} binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) program_set = ProgramSet(binding) - sub = program_set.split(4) - assert [len(s) for s in sub] == [1, 1] - assert len(sub[0].entries[0].input_sets) == 4 - assert len(sub[1].entries[0].input_sets) == 2 + subs, _ = program_set.split(4) + assert [len(s) for s in subs] == [1, 1] + assert len(subs[0].entries[0].input_sets) == 4 + assert len(subs[1].entries[0].input_sets) == 2 def test_split_binding_without_input_sets(circuit_rx_parametrized): @@ -665,10 +673,11 @@ def test_split_binding_without_input_sets(circuit_rx_parametrized): binding_a = CircuitBinding(c1, observables=[X(0), Y(0)]) # size 2 binding_b = CircuitBinding(c2, observables=[X(0), Y(0), Z(0)]) # size 3 program_set = ProgramSet([binding_a, binding_b]) - sub = program_set.split(3) - assert [s.total_executables for s in sub] == [2, 3] - assert sub[0].entries == [binding_a] - assert sub[1].entries == [binding_b] + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [2, 3] + assert subs[0].entries == [binding_a] + assert subs[1].entries == [binding_b] + assert mapping == [[0, 1], [2, 3, 4]] def test_split_non_positive_raises(circuit_rx_parametrized): @@ -680,14 +689,52 @@ def test_split_non_positive_raises(circuit_rx_parametrized): program_set.split(-3) -def test_split_oversize_class_raises(circuit_rx_parametrized): - # One parameter-set index with 3 observables exceeds max_executables=2. +def test_split_oversize_list_observables_are_chunked(circuit_rx_parametrized): + # A single class of 10 observables with max_executables=3 becomes 4 sub-program-sets + # of sizes 3, 3, 3, 1, each with a sliced observable list. + observables = [X(0), Y(0), Z(0), X(0), Y(0), Z(0), X(0), Y(0), Z(0), X(0)] + binding = CircuitBinding(circuit_rx_parametrized, observables=observables) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 3, 1] + assert mapping == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + slices = [list(s.entries[0].observables) for s in subs] + assert slices == [observables[0:3], observables[3:6], observables[6:9], observables[9:10]] + + +def test_split_oversize_sum_hamiltonian_is_chunked(circuit_rx_parametrized): + # Sum with 7 summands, max_executables=3 → sub-Sums of sizes 3, 3, 1 with + # coefficients preserved on each summand. + ham = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0) + 4.0 * X(0) + 5.0 * Y(0) + 6.0 * Z(0) + 7.0 * X(0) + binding = CircuitBinding(circuit_rx_parametrized, observables=ham) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 1] + assert mapping == [[0, 1, 2], [3, 4, 5], [6]] + # Each sub-observable is a Sum whose summands come from the original in order. + expected_summands = list(ham.summands) + got_summands: list = [] + for s in subs: + sub_obs = s.entries[0].observables + assert isinstance(sub_obs, type(ham)) + got_summands.extend(sub_obs.summands) + assert got_summands == expected_summands + + +def test_split_oversize_observables_with_multiple_param_sets(circuit_rx_parametrized): + # 2 parameter sets x 5 observables, max_executables=3 ⇒ each parameter-set index + # splits into two observable windows ((0,3) size 3 and (3,5) size 2). The packer + # can't coalesce across parameter sets because they're interleaved by window, so we + # end up with 4 sub-program-sets. inputs = {"theta": [1.0, 2.0]} - observables = [X(0), Y(0), Z(0)] + observables = [X(0), Y(0), Z(0), X(0), Y(0)] binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) program_set = ProgramSet(binding) - with pytest.raises(ValueError, match="exceeding max_executables"): - program_set.split(2) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 2, 3, 2] + # Mapping follows canonical ordering: ps=0,obs=0..4 = indices 0..4; ps=1,obs=0..4 = 5..9. + assert mapping == [[0, 1, 2], [3, 4], [5, 6, 7], [8, 9]] + assert sum(mapping, []) == list(range(program_set.total_executables)) def test_split_sub_program_sets_are_serializable(circuit_rx_parametrized): @@ -695,9 +742,55 @@ def test_split_sub_program_sets_are_serializable(circuit_rx_parametrized): observables = [X(0), Y(0)] binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) program_set = ProgramSet(binding) - sub = program_set.split(6) + subs, _ = program_set.split(6) # Each sub-program set is a fully formed ProgramSet: to_ir() works and returns a # single-program IR (one coalesced CircuitBinding per sub-program set here). - for s in sub: + for s in subs: ir = s.to_ir() assert len(ir.programs) == len(s) + + +def test_enumerate_executables_plain_circuits(): + ps = ProgramSet([ghz(1), ghz(2), ghz(3)]) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (1, 0, 0), (2, 0, 0)] + + +def test_enumerate_executables_binding_with_input_sets_only(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets={"theta": [0.1, 0.2, 0.3]}) + ps = ProgramSet(binding) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (0, 1, 0), (0, 2, 0)] + + +def test_enumerate_executables_binding_with_observables_only(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, observables=[X(0), Y(0), Z(0)]) + ps = ProgramSet(binding) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (0, 0, 1), (0, 0, 2)] + + +def test_enumerate_executables_mixed(): + # circuit, binding with 2 ps x 3 obs, binding with 2 ps no obs, binding with 4 obs no ps. + c0 = ghz(1) + c1 = Circuit().rx(0, FreeParameter("t")).cnot(0, 1) + c2 = Circuit().rx(0, FreeParameter("p")) + c3 = Circuit().h(0) + b1 = CircuitBinding(c1, {"t": [0.1, 0.2]}, [X(0), Y(0), Z(0)]) + b2 = CircuitBinding(c2, {"p": [0.3, 0.4]}) + b3 = CircuitBinding(c3, observables=[X(0), Y(0), Z(0), X(0) @ Y(1)]) + ps = ProgramSet([c0, b1, b2, b3]) + expected = [ + (0, 0, 0), + (1, 0, 0), + (1, 0, 1), + (1, 0, 2), + (1, 1, 0), + (1, 1, 1), + (1, 1, 2), + (2, 0, 0), + (2, 1, 0), + (3, 0, 0), + (3, 0, 1), + (3, 0, 2), + (3, 0, 3), + ] + assert list(ps.enumerate_executables()) == expected + assert len(expected) == ps.total_executables diff --git a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py index 3580c4eb0..0d8fac6b1 100644 --- a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py @@ -19,6 +19,8 @@ from braket.circuits import Circuit from braket.circuits.observables import X, Y, Z +from braket.circuits.serialization import IRType +from braket.ir.openqasm import Program from braket.parametric import FreeParameter from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.schema_common import BraketSchemaBase @@ -437,3 +439,367 @@ def test_dispatch_executable_result_with_none_inputs(execution_measurement_proba assert isinstance(measured_entry, MeasuredEntry) assert measured_entry.inputs is None assert measured_entry.probabilities == {"00": 0.7, "11": 0.3} + + +_SIM_METADATA_HEADER = { + "braketSchemaHeader": {"name": "braket.task_result.simulator_metadata", "version": "1"}, + "executionDuration": 50, +} +_DEVICE_PARAMS = { + "braketSchemaHeader": { + "name": "braket.device_schema.simulators.gate_model_simulator_device_parameters", + "version": "1", + }, + "paradigmParameters": { + "braketSchemaHeader": { + "name": "braket.device_schema.gate_model_parameters", + "version": "1", + }, + "qubitCount": 5, + "disableQubitRewiring": False, + }, +} + + +def _make_exec_result(inputs_index, probs=None): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_executable_result", + "version": "1", + }, + "inputsIndex": inputs_index, + "measurementProbabilities": probs or {"00": 0.7, "11": 0.3}, + "measuredQubits": [0, 1], + } + + +def _make_program_result(program_dict, executable_dicts): + return { + "braketSchemaHeader": {"name": "braket.task_result.program_result", "version": "1"}, + "executableResults": executable_dicts, + "source": program_dict, + "additionalMetadata": {"simulatorMetadata": dict(_SIM_METADATA_HEADER)}, + } + + +def _make_task_metadata( + program_executable_counts, task_id="arn:aws:braket:::task/sub", shots_per_executable=40 +): + total = sum(program_executable_counts) + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_task_metadata", + "version": "1", + }, + "id": task_id, + "deviceId": "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + "requestedShots": shots_per_executable * total, + "successfulShots": shots_per_executable * total, + "programMetadata": [ + {"executables": [{} for _ in range(n)]} for n in program_executable_counts + ], + "deviceParameters": dict(_DEVICE_PARAMS), + "createdAt": "2024-10-15T19:06:58.986Z", + "endedAt": "2024-10-15T19:07:00.382Z", + "status": "COMPLETED", + "totalFailedExecutables": 0, + } + + +def _make_task_result(program_results, metadata): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_task_result", + "version": "1", + }, + "programResults": program_results, + "taskMetadata": metadata, + } + + +def _parse(d): + return BraketSchemaBase.parse_raw_schema(json.dumps(d)) + + +def _build_sub_quantum_result(sub_program_set, programs_execs, shots_per_executable=40): + """Build a :class:`ProgramSetQuantumTaskResult` for a sub-program-set by first + building a wire-format ``ProgramSetTaskResult`` and passing it through + :meth:`ProgramSetQuantumTaskResult.from_object`. + + Args: + sub_program_set: The sub-``ProgramSet`` whose run produced the result. + programs_execs: One list of exec-result dicts per entry in ``sub_program_set.entries``. + shots_per_executable: shots per executable, propagated to the metadata. + """ + program_results = [] + counts = [] + for entry, execs in zip(sub_program_set.entries, programs_execs, strict=True): + if isinstance(entry, CircuitBinding): + source_dict = entry.to_ir().dict() + else: + source_dict = Program(source=entry.to_ir(IRType.OPENQASM).source, inputs=None).dict() + program_results.append(_make_program_result(source_dict, execs)) + counts.append(len(execs)) + wire = _parse( + _make_task_result( + program_results, _make_task_metadata(counts, shots_per_executable=shots_per_executable) + ) + ) + return ProgramSetQuantumTaskResult.from_object(wire, sub_program_set) + + +def test_from_multiple_single_sub_task_no_split_roundtrips(circuit_rx_parametrized_fixture): + """If split returns [self], from_multiple should reproduce from_object's output.""" + binding = CircuitBinding( + circuit_rx_parametrized_fixture, + input_sets={"theta": [0.12, 2.1]}, + observables=10 * Z(0) + X(0) - 0.01 * Y(0) @ X(1), + ) + ps = ProgramSet(binding) + subs, mapping = ps.split(100) # fits, so one sub-task identical to ps. + assert subs == [ps] + + # Build a ProgramSetQuantumTaskResult that represents running this ps: the wire + # payload goes through from_object first. + sub_program = subs[0].to_ir().programs[0].dict() + execs = [_make_exec_result(i) for i in range(ps.total_executables)] + wire = _parse( + _make_task_result( + [_make_program_result(sub_program, execs)], + _make_task_metadata([ps.total_executables]), + ) + ) + reference = ProgramSetQuantumTaskResult.from_object(wire, ps) + + merged = ProgramSetQuantumTaskResult.from_multiple([reference], ps, mapping) + + assert len(merged) == len(reference) == 1 + ref_composite = reference[0] + got_composite = merged[0] + assert len(got_composite) == len(ref_composite) + assert got_composite.program == ref_composite.program + assert got_composite.inputs == ref_composite.inputs + assert got_composite.observables == ref_composite.observables + for m_got, m_ref in zip(got_composite.entries, ref_composite.entries): + assert m_got.measured_qubits == m_ref.measured_qubits + assert m_got.probabilities == m_ref.probabilities + assert m_got.observable == m_ref.observable + assert m_got.inputs == m_ref.inputs + + +def test_from_multiple_split_list_observables(circuit_rx_parametrized_fixture): + """Split a binding with more observables than fit; scatter + regroup must + reconstruct the same CompositeEntry as running unsplit.""" + binding = CircuitBinding( + circuit_rx_parametrized_fixture, + input_sets={"theta": [0.12]}, + observables=[X(0), Y(0), Z(0), X(0) @ Y(1)], # 4 observables. + ) + ps = ProgramSet(binding) + subs, mapping = ps.split(2) # 4 > 2, so observables split into windows (0,2), (2,4). + assert [s.total_executables for s in subs] == [2, 2] + + # One sub-quantum-result per sub-program-set, built by running each through + # from_object on an inline wire payload. + sub_results = [ + _build_sub_quantum_result( + sub, [[_make_exec_result(i, {"00": 1.0}) for i in range(sub.total_executables)]] + ) + for sub in subs + ] + + merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + assert len(merged) == 1 + composite = merged[0] + # The merged composite should have 4 MeasuredEntries in canonical order, each with + # the ORIGINAL binding's observable attached at that index. + assert len(composite) == 4 + for i, measured in enumerate(composite.entries): + assert isinstance(measured, MeasuredEntry) + assert measured.observable == binding.observables[i] + assert composite.inputs == ParameterSets({"theta": [0.12]}) + # task metadata was aggregated across sub-tasks. + assert merged.num_executables == 4 + assert merged.task_metadata.requestedShots == sum( + r.task_metadata.requestedShots for r in sub_results + ) + assert merged.task_metadata.successfulShots == sum( + r.task_metadata.successfulShots for r in sub_results + ) + + +def test_from_multiple_split_sum_hamiltonian_reconstructs_expectation( + circuit_rx_parametrized_fixture, +): + """Splitting a Sum Hamiltonian across multiple sub-tasks and then merging must + reconstruct the full expectation value, because scatter+regroup feeds the original + Sum back into ``_compute_expectations``.""" + # Same fixture as existing test_observables_no_inputs (with known expectation). + circuit = Circuit().h(0).cnot(0, 1) + h = 10000 * Z(0) + 1000 * X(0) - 100 * Z(0) + 10 * Z(1) + X(1) - 0.1 * Y(1) + binding = CircuitBinding(circuit, observables=h) + ps = ProgramSet(binding) + assert ps.total_executables == 6 + + subs, mapping = ps.split(2) # 6 > 2, so Sum splits into 3 windows of size 2. + assert [s.total_executables for s in subs] == [2, 2, 2] + + # Each executable's measurement is the same {"00": 0.7, "11": 0.3} as the existing + # test_observables_no_inputs fixture, so the expectation should match 4364.36. + sub_results = [ + _build_sub_quantum_result( + sub, [[_make_exec_result(i) for i in range(sub.total_executables)]] + ) + for sub in subs + ] + + merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + composite = merged[0] + assert composite.observables is h + assert len(composite) == 6 + assert np.isclose(composite.expectation(), 4364.36) + + +def test_from_multiple_mixed_bindings_and_failures(circuit_rx_parametrized_fixture): + """A program set with multiple bindings, split across sub-tasks, with one + executable failing in a sub-task. Failures must land at the correct original + position in the merged result.""" + c1 = circuit_rx_parametrized_fixture + c2 = Circuit().rx(0, FreeParameter("phi")) + b1 = CircuitBinding(c1, {"theta": [0.1, 0.2, 0.3]}, observables=[X(0), Y(0)]) # 6 execs + b2 = CircuitBinding(c2, {"phi": [0.4, 0.5]}) # 2 execs, no observables + ps = ProgramSet([b1, b2]) + assert ps.total_executables == 8 + + subs, mapping = ps.split(5) + # Greedy pack with max=5: b1 classes (sizes 2,2,2) fill [2+2=4, +2>5 flush], so + # sub 0 = 2 b1 classes (4 execs), sub 1 = 1 b1 class (2 execs) + b2 (2 execs) = 4 execs. + assert [s.total_executables for s in subs] == [4, 4] + + def _failure(inputs_index): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_executable_failure", + "version": "1", + }, + "inputsIndex": inputs_index, + "failureMetadata": { + "failureReason": "test failure", + "retryable": False, + "category": "DEVICE", + }, + } + + # Inject a failure at original index 5 (b1 ps=2, obs=1) which lives in sub 1. + sub_results = [] + failure_injected = False + for k, sub in enumerate(subs): + programs_execs = [] + for prog_idx, entry in enumerate(sub.entries): + num_execs = len(entry) if isinstance(entry, CircuitBinding) else 1 + execs = [] + for i in range(num_execs): + # Figure out this sub-executable's original index. Within sub k, + # j runs across all programs so we need a running counter. + j = ( + sum( + len(prev_entry) if isinstance(prev_entry, CircuitBinding) else 1 + for prev_entry in sub.entries[:prog_idx] + ) + + i + ) + if mapping[k][j] == 5: + execs.append(_failure(i)) + failure_injected = True + else: + execs.append(_make_exec_result(i)) + programs_execs.append(execs) + sub_results.append(_build_sub_quantum_result(sub, programs_execs)) + + assert failure_injected + merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + assert len(merged) == 2 + # Binding 0: 6 executables, position 5 is a failure. + assert len(merged[0]) == 6 + # Binding 1: 2 executables, all successful. + assert len(merged[1]) == 2 + from braket.task_result import ProgramSetExecutableFailure + + assert isinstance(merged[0].entries[5], ProgramSetExecutableFailure) + # All non-failure entries for binding 0 have the correct observables. + for i, entry in enumerate(merged[0].entries): + if isinstance(entry, MeasuredEntry): + expected_obs = b1.observables[i % len(b1.observables)] + assert entry.observable == expected_obs + # Binding 1 entries have no observable. + for entry in merged[1].entries: + if isinstance(entry, MeasuredEntry): + assert entry.observable is None + + +def test_from_multiple_validates_mapping_size(circuit_rx_parametrized_fixture): + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) + # mapping has 1 entry for 1 sub-task, but size doesn't match ps.total_executables. + with pytest.raises(ValueError, match="Index map covers 1"): + ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0]]) + # Sub-task count doesn't match mapping's length. + with pytest.raises(ValueError, match="1 task results but 2 entries in index_map"): + ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0], [1]]) + + +@pytest.fixture +def circuit_rx_parametrized_fixture(): + return Circuit().rx(0, FreeParameter("theta")).cnot(0, 1) + + +def test_from_multiple_with_plain_circuit_entries(): + """from_multiple should handle plain Circuit entries (no inputs, no observables).""" + c1 = ghz_test(2) + c2 = ghz_test(1) + ps = ProgramSet([c1, c2]) + subs, mapping = ps.split(1) + assert [s.total_executables for s in subs] == [1, 1] + + sub_results = [_build_sub_quantum_result(sub, [[_make_exec_result(0)]]) for sub in subs] + + merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + assert len(merged) == 2 + assert len(merged[0]) == 1 + assert len(merged[1]) == 1 + assert merged[0].observables is None + assert merged[0].entries[0].observable is None + assert merged[0].entries[0].inputs is None + + +def test_from_multiple_rejects_sub_task_over_mapping(circuit_rx_parametrized_fixture): + """Sub-task has more executables than mapping[k] covers.""" + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + # Sub-task reports 2 executables, but mapping says there's only 1. + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) + with pytest.raises(ValueError, match="produced more executables than index map"): + ProgramSetQuantumTaskResult.from_multiple( + [sub_result], + ProgramSet(CircuitBinding(circuit_rx_parametrized_fixture, {"theta": [0.1]})), + [[0]], + ) + + +def test_from_multiple_rejects_sub_task_under_mapping(circuit_rx_parametrized_fixture): + """Sub-task has fewer executables than mapping[k] covers.""" + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + # Sub-task reports only 1 executable, but mapping says there are 2. + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0)]]) + with pytest.raises(ValueError, match="expected 2"): + ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0, 1]]) + + +def ghz_test(n): + """Local ghz helper so tests don't depend on program_set_test_utils.""" + circuit = Circuit().h(0) + for i in range(n - 1): + circuit.cnot(i, i + 1) + return circuit From 5a17aa7a705a6d002b0429f51e5a5277f664cc70 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 8 May 2026 17:15:13 -0700 Subject: [PATCH 3/9] rename --- src/braket/program_sets/program_set.py | 2 +- .../tasks/program_set_quantum_task_result.py | 2 +- .../test_program_set_quantum_task_result.py | 18 +++++++++--------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index f3935b66e..a2b7b31cb 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -111,7 +111,7 @@ def enumerate_executables(self) -> Iterator[tuple[int, int, int]]: ``Sum`` Hamiltonian, ``observable_index`` ranges over the summands. This ordering is used by ``split`` to build its index map and by - ``ProgramSetQuantumTaskResult.from_multiple`` to merge results back into the original shape. + ``ProgramSetQuantumTaskResult.merge`` to merge results back into the original shape. Yields: tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``. diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index 02c1e6091..9313ee2d5 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -376,7 +376,7 @@ def from_object( ) @staticmethod - def from_multiple( + def merge( results: Sequence[ProgramSetQuantumTaskResult], program_set: ProgramSet, index_map: list[list[int]], diff --git a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py index 0d8fac6b1..886a4a22e 100644 --- a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py @@ -571,7 +571,7 @@ def test_from_multiple_single_sub_task_no_split_roundtrips(circuit_rx_parametriz ) reference = ProgramSetQuantumTaskResult.from_object(wire, ps) - merged = ProgramSetQuantumTaskResult.from_multiple([reference], ps, mapping) + merged = ProgramSetQuantumTaskResult.merge([reference], ps, mapping) assert len(merged) == len(reference) == 1 ref_composite = reference[0] @@ -608,7 +608,7 @@ def test_from_multiple_split_list_observables(circuit_rx_parametrized_fixture): for sub in subs ] - merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) assert len(merged) == 1 composite = merged[0] # The merged composite should have 4 MeasuredEntries in canonical order, each with @@ -653,7 +653,7 @@ def test_from_multiple_split_sum_hamiltonian_reconstructs_expectation( for sub in subs ] - merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) composite = merged[0] assert composite.observables is h assert len(composite) == 6 @@ -717,7 +717,7 @@ def _failure(inputs_index): sub_results.append(_build_sub_quantum_result(sub, programs_execs)) assert failure_injected - merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) assert len(merged) == 2 # Binding 0: 6 executables, position 5 is a failure. assert len(merged[0]) == 6 @@ -743,10 +743,10 @@ def test_from_multiple_validates_mapping_size(circuit_rx_parametrized_fixture): sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) # mapping has 1 entry for 1 sub-task, but size doesn't match ps.total_executables. with pytest.raises(ValueError, match="Index map covers 1"): - ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0]]) + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0]]) # Sub-task count doesn't match mapping's length. with pytest.raises(ValueError, match="1 task results but 2 entries in index_map"): - ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0], [1]]) + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0], [1]]) @pytest.fixture @@ -764,7 +764,7 @@ def test_from_multiple_with_plain_circuit_entries(): sub_results = [_build_sub_quantum_result(sub, [[_make_exec_result(0)]]) for sub in subs] - merged = ProgramSetQuantumTaskResult.from_multiple(sub_results, ps, mapping) + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) assert len(merged) == 2 assert len(merged[0]) == 1 assert len(merged[1]) == 1 @@ -780,7 +780,7 @@ def test_from_multiple_rejects_sub_task_over_mapping(circuit_rx_parametrized_fix # Sub-task reports 2 executables, but mapping says there's only 1. sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) with pytest.raises(ValueError, match="produced more executables than index map"): - ProgramSetQuantumTaskResult.from_multiple( + ProgramSetQuantumTaskResult.merge( [sub_result], ProgramSet(CircuitBinding(circuit_rx_parametrized_fixture, {"theta": [0.1]})), [[0]], @@ -794,7 +794,7 @@ def test_from_multiple_rejects_sub_task_under_mapping(circuit_rx_parametrized_fi # Sub-task reports only 1 executable, but mapping says there are 2. sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0)]]) with pytest.raises(ValueError, match="expected 2"): - ProgramSetQuantumTaskResult.from_multiple([sub_result], ps, [[0, 1]]) + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0, 1]]) def ghz_test(n): From 0e2ad55cf1a78c94d71f2519fd9f6d1b35f25721 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 8 May 2026 17:27:01 -0700 Subject: [PATCH 4/9] Update program_set.py --- src/braket/program_sets/program_set.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index a2b7b31cb..ee909e145 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -227,15 +227,15 @@ def _enumerate_triples(self, max_executables: int) -> list[_Triple]: return triples def _build_program_set(self, triples: list[_Triple]) -> tuple[ProgramSet, list[int]]: - entries: list[CircuitBinding | Circuit] = [] - sub_sub_map: list[int] = [] + entries = [] + sub_map = [] i = 0 while i < len(triples): head = triples[i] prog = self._programs[head.prog_idx] if head.param_set_index is None: entries.append(_apply_obs_slice(prog, head.obs_slice)) - sub_sub_map.extend(head.original_indices) + sub_map.extend(head.original_indices) i += 1 continue @@ -255,9 +255,9 @@ def _build_program_set(self, triples: list[_Triple]) -> tuple[ProgramSet, list[i ) entries.append(coalesced) for k in range(i, j + 1): - sub_sub_map.extend(triples[k].original_indices) + sub_map.extend(triples[k].original_indices) i = j + 1 - return ProgramSet(entries, self._shots_per_executable), sub_sub_map + return ProgramSet(entries, self._shots_per_executable), sub_map @staticmethod def zip( From aba672254f04b0e7c6074d5324a450303404d392 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 8 May 2026 18:41:32 -0700 Subject: [PATCH 5/9] rename --- src/braket/program_sets/program_set.py | 82 ++++++++--------- .../tasks/program_set_quantum_task_result.py | 87 ++++++++----------- 2 files changed, 79 insertions(+), 90 deletions(-) diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index ee909e145..73e092a0a 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -116,15 +116,14 @@ def enumerate_executables(self) -> Iterator[tuple[int, int, int]]: Yields: tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``. """ - for b_idx, prog in enumerate(self._programs): + for binding_idx, prog in enumerate(self._programs): if isinstance(prog, Circuit): - yield b_idx, 0, 0 + yield binding_idx, 0, 0 continue - num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 num_obs = len(prog.observables) if prog.observables is not None else 1 - for ps_idx in range(num_ps): + for ps_idx in range(len(prog.input_sets) if prog.input_sets is not None else 1): for obs_idx in range(num_obs): - yield b_idx, ps_idx, obs_idx + yield binding_idx, ps_idx, obs_idx def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]: """Split this program set into program sets of at most ``max_executables`` executables, @@ -169,33 +168,32 @@ def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]] if self.total_executables <= max_executables: return [self], [list(range(self.total_executables))] - triples = self._enumerate_triples(max_executables) program_sets = [] index_map = [] current = [] current_size = 0 - for cls in triples: - if current and current_size + cls.size > max_executables: + for block in self._executable_blocks(max_executables): + if current and current_size + block.size > max_executables: sub, sub_map = self._build_program_set(current) program_sets.append(sub) index_map.append(sub_map) current = [] current_size = 0 - current.append(cls) - current_size += cls.size + current.append(block) + current_size += block.size sub, sub_map = self._build_program_set(current) program_sets.append(sub) index_map.append(sub_map) return program_sets, index_map - def _enumerate_triples(self, max_executables: int) -> list[_Triple]: - triples = [] + def _executable_blocks(self, max_executables: int) -> list[_ExecutableBlock]: + blocks = [] orig_idx = 0 for prog_idx, prog in enumerate(self._programs): if isinstance(prog, Circuit): - triples.append( - _Triple( + blocks.append( + _ExecutableBlock( prog_idx=prog_idx, param_set_index=None, obs_slice=None, @@ -206,16 +204,16 @@ def _enumerate_triples(self, max_executables: int) -> list[_Triple]: orig_idx += 1 continue - num_obs = len(prog.observables) if prog.observables is not None else 1 num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 - ps_indices = list(range(num_ps)) if prog.input_sets is not None else [None] - obs_windows = _observable_windows(num_obs, max_executables) + obs_windows = _observable_windows( + len(prog.observables) if prog.observables is not None else 1, max_executables + ) split_observables = len(obs_windows) > 1 - for ps_idx in ps_indices: + for ps_idx in range(num_ps) if prog.input_sets is not None else [None]: for start, stop in obs_windows: size = stop - start - triples.append( - _Triple( + blocks.append( + _ExecutableBlock( prog_idx=prog_idx, param_set_index=ps_idx, obs_slice=slice(start, stop) if split_observables else None, @@ -224,14 +222,14 @@ def _enumerate_triples(self, max_executables: int) -> list[_Triple]: ) ) orig_idx += size - return triples + return blocks - def _build_program_set(self, triples: list[_Triple]) -> tuple[ProgramSet, list[int]]: + def _build_program_set(self, blocks: list[_ExecutableBlock]) -> tuple[ProgramSet, list[int]]: entries = [] sub_map = [] i = 0 - while i < len(triples): - head = triples[i] + while i < len(blocks): + head = blocks[i] prog = self._programs[head.prog_idx] if head.param_set_index is None: entries.append(_apply_obs_slice(prog, head.obs_slice)) @@ -241,21 +239,23 @@ def _build_program_set(self, triples: list[_Triple]) -> tuple[ProgramSet, list[i j = i while ( - j + 1 < len(triples) - and triples[j + 1].prog_idx == head.prog_idx - and triples[j + 1].obs_slice == triples[j].obs_slice - and triples[j + 1].param_set_index == triples[j].param_set_index + 1 + j + 1 < len(blocks) + and blocks[j + 1].prog_idx == head.prog_idx + and blocks[j + 1].obs_slice == blocks[j].obs_slice + and blocks[j + 1].param_set_index == blocks[j].param_set_index + 1 ): j += 1 - start, stop = head.param_set_index, triples[j].param_set_index + 1 - coalesced = CircuitBinding( - prog.circuit, - input_sets=prog.input_sets.as_list()[start:stop], - observables=_slice_observables(prog.observables, head.obs_slice), + start = head.param_set_index + stop = blocks[j].param_set_index + 1 + entries.append( + CircuitBinding( + prog.circuit, + input_sets=prog.input_sets.as_list()[start:stop], + observables=_slice_observables(prog.observables, head.obs_slice), + ) ) - entries.append(coalesced) for k in range(i, j + 1): - sub_map.extend(triples[k].original_indices) + sub_map.extend(blocks[k].original_indices) i = j + 1 return ProgramSet(entries, self._shots_per_executable), sub_map @@ -368,10 +368,10 @@ def __repr__(self): ) -@dataclass(frozen=True) -class _Triple: - """A contiguous run of executables sharing the same ``(circuit, observable list, - single parameter assignment)`` triple. +@dataclass +class _ExecutableBlock: + """Multi-index range for an equivalence class of executables sharing the same combination of + ``(circuit, observable list/Sum Hamiltonian, single parameter assignment)``. Attributes: prog_idx: Index of the originating program in ``ProgramSet.entries``. @@ -380,8 +380,8 @@ class _Triple: obs_slice: Slice into the originating observable list or ``Sum`` summands when observables were split to fit the budget; ``None`` means the full original observable list (or no observables). - size: Number of executables this triple represents (== ``len(original_indices)``). - original_indices: The indices of this triple's executables + size: Number of executables this block represents (== ``len(original_indices)``). + original_indices: The indices of this block's executables in the order of the original program set. """ diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index 9313ee2d5..99178171e 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -423,13 +423,10 @@ def merge( f"has {total_executables}" ) - binding_programs = [_binding_to_program(binding) for binding in program_set.entries] - triples = list(program_set.enumerate_executables()) + programs = [_binding_to_program(binding) for binding in program_set.entries] + executable_indices = list(program_set.enumerate_executables()) binding_executable_counts = [_count_executables(b) for b in program_set.entries] - - metas = [r.task_metadata for r in results] - first_num_execs = _num_executables_from_metadata(metas[0]) - shots_per_executable = metas[0].requestedShots // first_num_execs if first_num_execs else 0 + shots_per_executable = results[0].entries[0].shots_per_executable buffer = [None] * total_executables for k, result in enumerate(results): @@ -438,8 +435,8 @@ def merge( result=result, map_k=index_map[k], program_set=program_set, - binding_programs=binding_programs, - triples=triples, + programs=programs, + executable_indices=executable_indices, buffer=buffer, ) @@ -447,8 +444,8 @@ def merge( start = 0 for binding_idx, binding in enumerate(program_set.entries): count = binding_executable_counts[binding_idx] - program = binding_programs[binding_idx] - observables = None if isinstance(binding, Circuit) else binding.observables + program = programs[binding_idx] + observables = binding.observables if isinstance(binding, CircuitBinding) else None entries.append( CompositeEntry( entries=buffer[start : start + count], @@ -461,9 +458,29 @@ def merge( ) start += count + metas = [r.task_metadata for r in results] return ProgramSetQuantumTaskResult( entries=entries, - task_metadata=_aggregate_task_metadata(metas, program_set), + task_metadata=ProgramSetTaskMetadata( + id=";".join(meta.id for meta in metas), # Better way to do this? + deviceId=metas[0].deviceId, + requestedShots=sum(m.requestedShots for m in metas), + successfulShots=sum(m.successfulShots for m in metas), + programMetadata=[ + ProgramMetadata( + executables=[ + ProgramSetExecutableResultMetadata() + for _ in range(_count_executables(b)) + ] + ) + for b in program_set.entries + ], + deviceParameters=None, # TODO: find a way to fill this in + createdAt=min(m.createdAt for m in metas if m.createdAt), + endedAt=max(m.endedAt for m in metas if m.endedAt), + status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", + totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), + ), num_executables=total_executables, program_set=program_set, ) @@ -593,18 +610,13 @@ def _count_executables(binding: CircuitBinding | Circuit) -> int: return num_ps * num_obs -def _num_executables_from_metadata(metadata: ProgramSetTaskMetadata) -> int: - return sum(len(p.executables) for p in metadata.programMetadata) - - def _buffer_result( - *, k: int, result: ProgramSetQuantumTaskResult, map_k: list[int], program_set: ProgramSet, - binding_programs: list[Program], - triples: list[tuple[int, int, int]], + programs: list[Program], + executable_indices: list[tuple[int, int, int]], buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], ) -> None: j = 0 @@ -616,11 +628,11 @@ def _buffer_result( "produced more executables than index map expects" ) orig_idx = map_k[j] - b_idx, ps_idx, obs_idx = triples[orig_idx] + binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] buffer[orig_idx] = _convert_measured_entry( entry, - program_set.entries[b_idx], - binding_programs[b_idx], + program_set.entries[binding_idx], + programs[binding_idx], ps_idx, obs_idx, ) @@ -655,34 +667,11 @@ def _convert_measured_entry( num_obs = len(observables) orig_inputs_index = parameter_set_index * num_obs + observable_index program_inputs = original_program.inputs or {} - inputs = {key: value[orig_inputs_index] for key, value in program_inputs.items()} or None - return replace(entry, program=original_program.source, inputs=inputs, observable=observable) - - -def _aggregate_task_metadata( - metas: Sequence[ProgramSetTaskMetadata], program_set: ProgramSet -) -> ProgramSetTaskMetadata: - first = metas[0] - created_values = [m.createdAt for m in metas if m.createdAt] - ended_values = [m.endedAt for m in metas if m.endedAt] - return ProgramSetTaskMetadata( - id=";".join(meta.id for meta in metas), - deviceId=first.deviceId, - requestedShots=sum(m.requestedShots for m in metas), - successfulShots=sum(m.successfulShots for m in metas), - programMetadata=[ - ProgramMetadata( - executables=[ - ProgramSetExecutableResultMetadata() for _ in range(_count_executables(b)) - ] - ) - for b in program_set.entries - ], - deviceParameters=None, - createdAt=min(created_values) if created_values else None, - endedAt=max(ended_values) if ended_values else None, - status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", - totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), + return replace( + entry, + program=original_program.source, + inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, + observable=observable, ) From 746790dc0fda3ec3bd24c8b750693450c3554e6b Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Sun, 10 May 2026 14:19:16 -0700 Subject: [PATCH 6/9] Remove results merging Will move to separate PR --- .../tasks/program_set_quantum_task_result.py | 202 +--------- .../test_program_set_quantum_task_result.py | 366 ------------------ 2 files changed, 4 insertions(+), 564 deletions(-) diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index 99178171e..9f7084c97 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -16,7 +16,7 @@ import warnings from collections import Counter from collections.abc import Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass import boto3 import numpy as np @@ -31,15 +31,10 @@ ProgramSetTaskMetadata, ProgramSetTaskResult, ) -from braket.task_result.program_set_executable_result_v1 import ( - ProgramSetExecutableResultMetadata, -) -from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata -from braket.circuits import Circuit, Observable +from braket.circuits import Observable from braket.circuits.observable import EULER_OBSERVABLE_PREFIX from braket.circuits.observables import Sum -from braket.circuits.serialization import IRType from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.tasks.measurement_utils import ( expectation_from_measurements, @@ -265,7 +260,7 @@ def _get_inputs(program: Program, observables: Sum | list[Observable] | None) -> def _get_executable_results( executable_results: Sequence[ ProgramSetExecutableResult | ProgramSetExecutableFailure | str - ], + ], program: Program, observables: Sum | list[Observable] | None, shots_per_executable: int, @@ -305,7 +300,7 @@ def _dispatch_executable_result( program=program.source, shots=shots_per_executable, inputs={k: v[result.inputsIndex] for k, v in (program.inputs or {}).items()} - or None, + or None, observable=( observables[result.inputsIndex % len(observables)] if observables else None ), @@ -375,116 +370,6 @@ def from_object( program_set=program_set, ) - @staticmethod - def merge( - results: Sequence[ProgramSetQuantumTaskResult], - program_set: ProgramSet, - index_map: list[list[int]], - ) -> ProgramSetQuantumTaskResult: - """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running - each program set of ``program_set.split(...)``. - - ``index_map`` is the per-executable map returned alongside the program sets by - ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, - of the executable that the jth executable of the kth task represents. The kth task's - executables are read in order for its program set, namely across ``results[k].entries``, - and within each ``CompositeEntry`` across its ``entries``. - - The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had - been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, - and ``MeasuredEntry`` objects in the order of the program. - - Expectation values and ``Sum`` Hamiltonian expectations are computed - for the original ``ProgramSet``. - - Args: - results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same - order as ``program_set.split``'s return. - program_set (ProgramSet): The original unsplit program set. - index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. - - Returns: - ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. - - Raises: - ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` - doesn't match ``program_set.total_executables``, or if any task produces a - different number of executables than its map expects. - """ - if len(results) != len(index_map): - raise ValueError( - f"Got {len(results)} task results but {len(index_map)} entries in index_map" - ) - total_executables = program_set.total_executables - total_mapped = sum(len(m) for m in index_map) - if total_mapped != total_executables: - raise ValueError( - f"Index map covers {total_mapped} executables but the original program set " - f"has {total_executables}" - ) - - programs = [_binding_to_program(binding) for binding in program_set.entries] - executable_indices = list(program_set.enumerate_executables()) - binding_executable_counts = [_count_executables(b) for b in program_set.entries] - shots_per_executable = results[0].entries[0].shots_per_executable - - buffer = [None] * total_executables - for k, result in enumerate(results): - _buffer_result( - k=k, - result=result, - map_k=index_map[k], - program_set=program_set, - programs=programs, - executable_indices=executable_indices, - buffer=buffer, - ) - - entries = [] - start = 0 - for binding_idx, binding in enumerate(program_set.entries): - count = binding_executable_counts[binding_idx] - program = programs[binding_idx] - observables = binding.observables if isinstance(binding, CircuitBinding) else None - entries.append( - CompositeEntry( - entries=buffer[start : start + count], - program=program, - inputs=CompositeEntry._get_inputs(program, observables), - observables=observables, - shots_per_executable=shots_per_executable, - additional_metadata=None, - ) - ) - start += count - - metas = [r.task_metadata for r in results] - return ProgramSetQuantumTaskResult( - entries=entries, - task_metadata=ProgramSetTaskMetadata( - id=";".join(meta.id for meta in metas), # Better way to do this? - deviceId=metas[0].deviceId, - requestedShots=sum(m.requestedShots for m in metas), - successfulShots=sum(m.successfulShots for m in metas), - programMetadata=[ - ProgramMetadata( - executables=[ - ProgramSetExecutableResultMetadata() - for _ in range(_count_executables(b)) - ] - ) - for b in program_set.entries - ], - deviceParameters=None, # TODO: find a way to fill this in - createdAt=min(m.createdAt for m in metas if m.createdAt), - endedAt=max(m.endedAt for m in metas if m.endedAt), - status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", - totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), - ), - num_executables=total_executables, - program_set=program_set, - ) - def __len__(self): return len(self.entries) @@ -596,85 +481,6 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: return counter -def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: - if isinstance(binding, Circuit): - return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) - return binding.to_ir() - - -def _count_executables(binding: CircuitBinding | Circuit) -> int: - if isinstance(binding, Circuit): - return 1 - num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 - num_obs = len(binding.observables) if binding.observables is not None else 1 - return num_ps * num_obs - - -def _buffer_result( - k: int, - result: ProgramSetQuantumTaskResult, - map_k: list[int], - program_set: ProgramSet, - programs: list[Program], - executable_indices: list[tuple[int, int, int]], - buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], -) -> None: - j = 0 - for composite in result.entries: - for entry in composite.entries: - if j >= len(map_k): - raise ValueError( - f"t=Task {result.task_metadata.id} at index {k} " - "produced more executables than index map expects" - ) - orig_idx = map_k[j] - binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] - buffer[orig_idx] = _convert_measured_entry( - entry, - program_set.entries[binding_idx], - programs[binding_idx], - ps_idx, - obs_idx, - ) - j += 1 - if j != len(map_k): - raise ValueError( - f"Task {result.task_metadata.id} at index {k} produced {j} executables " - f"but index map expected {len(map_k)}" - ) - - -def _convert_measured_entry( - entry: MeasuredEntry | ProgramSetExecutableFailure, - original_binding: CircuitBinding | Circuit, - original_program: Program, - parameter_set_index: int, - observable_index: int, -) -> MeasuredEntry | ProgramSetExecutableFailure: - if isinstance(entry, ProgramSetExecutableFailure): - return entry - if isinstance(original_binding, Circuit): - return replace(entry, program=original_program.source, inputs=None, observable=None) - observables = original_binding.observables - if observables is None: - observable: Observable | None = None - num_obs = 1 - elif isinstance(observables, Sum): - observable = observables.summands[observable_index] - num_obs = len(observables.summands) - else: - observable = observables[observable_index] - num_obs = len(observables) - orig_inputs_index = parameter_set_index * num_obs + observable_index - program_inputs = original_program.inputs or {} - return replace( - entry, - program=original_program.source, - inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, - observable=observable, - ) - - def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: """Retrieve the S3 object body. diff --git a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py index 886a4a22e..3580c4eb0 100644 --- a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py @@ -19,8 +19,6 @@ from braket.circuits import Circuit from braket.circuits.observables import X, Y, Z -from braket.circuits.serialization import IRType -from braket.ir.openqasm import Program from braket.parametric import FreeParameter from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.schema_common import BraketSchemaBase @@ -439,367 +437,3 @@ def test_dispatch_executable_result_with_none_inputs(execution_measurement_proba assert isinstance(measured_entry, MeasuredEntry) assert measured_entry.inputs is None assert measured_entry.probabilities == {"00": 0.7, "11": 0.3} - - -_SIM_METADATA_HEADER = { - "braketSchemaHeader": {"name": "braket.task_result.simulator_metadata", "version": "1"}, - "executionDuration": 50, -} -_DEVICE_PARAMS = { - "braketSchemaHeader": { - "name": "braket.device_schema.simulators.gate_model_simulator_device_parameters", - "version": "1", - }, - "paradigmParameters": { - "braketSchemaHeader": { - "name": "braket.device_schema.gate_model_parameters", - "version": "1", - }, - "qubitCount": 5, - "disableQubitRewiring": False, - }, -} - - -def _make_exec_result(inputs_index, probs=None): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_executable_result", - "version": "1", - }, - "inputsIndex": inputs_index, - "measurementProbabilities": probs or {"00": 0.7, "11": 0.3}, - "measuredQubits": [0, 1], - } - - -def _make_program_result(program_dict, executable_dicts): - return { - "braketSchemaHeader": {"name": "braket.task_result.program_result", "version": "1"}, - "executableResults": executable_dicts, - "source": program_dict, - "additionalMetadata": {"simulatorMetadata": dict(_SIM_METADATA_HEADER)}, - } - - -def _make_task_metadata( - program_executable_counts, task_id="arn:aws:braket:::task/sub", shots_per_executable=40 -): - total = sum(program_executable_counts) - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_task_metadata", - "version": "1", - }, - "id": task_id, - "deviceId": "arn:aws:braket:::device/quantum-simulator/amazon/sv1", - "requestedShots": shots_per_executable * total, - "successfulShots": shots_per_executable * total, - "programMetadata": [ - {"executables": [{} for _ in range(n)]} for n in program_executable_counts - ], - "deviceParameters": dict(_DEVICE_PARAMS), - "createdAt": "2024-10-15T19:06:58.986Z", - "endedAt": "2024-10-15T19:07:00.382Z", - "status": "COMPLETED", - "totalFailedExecutables": 0, - } - - -def _make_task_result(program_results, metadata): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_task_result", - "version": "1", - }, - "programResults": program_results, - "taskMetadata": metadata, - } - - -def _parse(d): - return BraketSchemaBase.parse_raw_schema(json.dumps(d)) - - -def _build_sub_quantum_result(sub_program_set, programs_execs, shots_per_executable=40): - """Build a :class:`ProgramSetQuantumTaskResult` for a sub-program-set by first - building a wire-format ``ProgramSetTaskResult`` and passing it through - :meth:`ProgramSetQuantumTaskResult.from_object`. - - Args: - sub_program_set: The sub-``ProgramSet`` whose run produced the result. - programs_execs: One list of exec-result dicts per entry in ``sub_program_set.entries``. - shots_per_executable: shots per executable, propagated to the metadata. - """ - program_results = [] - counts = [] - for entry, execs in zip(sub_program_set.entries, programs_execs, strict=True): - if isinstance(entry, CircuitBinding): - source_dict = entry.to_ir().dict() - else: - source_dict = Program(source=entry.to_ir(IRType.OPENQASM).source, inputs=None).dict() - program_results.append(_make_program_result(source_dict, execs)) - counts.append(len(execs)) - wire = _parse( - _make_task_result( - program_results, _make_task_metadata(counts, shots_per_executable=shots_per_executable) - ) - ) - return ProgramSetQuantumTaskResult.from_object(wire, sub_program_set) - - -def test_from_multiple_single_sub_task_no_split_roundtrips(circuit_rx_parametrized_fixture): - """If split returns [self], from_multiple should reproduce from_object's output.""" - binding = CircuitBinding( - circuit_rx_parametrized_fixture, - input_sets={"theta": [0.12, 2.1]}, - observables=10 * Z(0) + X(0) - 0.01 * Y(0) @ X(1), - ) - ps = ProgramSet(binding) - subs, mapping = ps.split(100) # fits, so one sub-task identical to ps. - assert subs == [ps] - - # Build a ProgramSetQuantumTaskResult that represents running this ps: the wire - # payload goes through from_object first. - sub_program = subs[0].to_ir().programs[0].dict() - execs = [_make_exec_result(i) for i in range(ps.total_executables)] - wire = _parse( - _make_task_result( - [_make_program_result(sub_program, execs)], - _make_task_metadata([ps.total_executables]), - ) - ) - reference = ProgramSetQuantumTaskResult.from_object(wire, ps) - - merged = ProgramSetQuantumTaskResult.merge([reference], ps, mapping) - - assert len(merged) == len(reference) == 1 - ref_composite = reference[0] - got_composite = merged[0] - assert len(got_composite) == len(ref_composite) - assert got_composite.program == ref_composite.program - assert got_composite.inputs == ref_composite.inputs - assert got_composite.observables == ref_composite.observables - for m_got, m_ref in zip(got_composite.entries, ref_composite.entries): - assert m_got.measured_qubits == m_ref.measured_qubits - assert m_got.probabilities == m_ref.probabilities - assert m_got.observable == m_ref.observable - assert m_got.inputs == m_ref.inputs - - -def test_from_multiple_split_list_observables(circuit_rx_parametrized_fixture): - """Split a binding with more observables than fit; scatter + regroup must - reconstruct the same CompositeEntry as running unsplit.""" - binding = CircuitBinding( - circuit_rx_parametrized_fixture, - input_sets={"theta": [0.12]}, - observables=[X(0), Y(0), Z(0), X(0) @ Y(1)], # 4 observables. - ) - ps = ProgramSet(binding) - subs, mapping = ps.split(2) # 4 > 2, so observables split into windows (0,2), (2,4). - assert [s.total_executables for s in subs] == [2, 2] - - # One sub-quantum-result per sub-program-set, built by running each through - # from_object on an inline wire payload. - sub_results = [ - _build_sub_quantum_result( - sub, [[_make_exec_result(i, {"00": 1.0}) for i in range(sub.total_executables)]] - ) - for sub in subs - ] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 1 - composite = merged[0] - # The merged composite should have 4 MeasuredEntries in canonical order, each with - # the ORIGINAL binding's observable attached at that index. - assert len(composite) == 4 - for i, measured in enumerate(composite.entries): - assert isinstance(measured, MeasuredEntry) - assert measured.observable == binding.observables[i] - assert composite.inputs == ParameterSets({"theta": [0.12]}) - # task metadata was aggregated across sub-tasks. - assert merged.num_executables == 4 - assert merged.task_metadata.requestedShots == sum( - r.task_metadata.requestedShots for r in sub_results - ) - assert merged.task_metadata.successfulShots == sum( - r.task_metadata.successfulShots for r in sub_results - ) - - -def test_from_multiple_split_sum_hamiltonian_reconstructs_expectation( - circuit_rx_parametrized_fixture, -): - """Splitting a Sum Hamiltonian across multiple sub-tasks and then merging must - reconstruct the full expectation value, because scatter+regroup feeds the original - Sum back into ``_compute_expectations``.""" - # Same fixture as existing test_observables_no_inputs (with known expectation). - circuit = Circuit().h(0).cnot(0, 1) - h = 10000 * Z(0) + 1000 * X(0) - 100 * Z(0) + 10 * Z(1) + X(1) - 0.1 * Y(1) - binding = CircuitBinding(circuit, observables=h) - ps = ProgramSet(binding) - assert ps.total_executables == 6 - - subs, mapping = ps.split(2) # 6 > 2, so Sum splits into 3 windows of size 2. - assert [s.total_executables for s in subs] == [2, 2, 2] - - # Each executable's measurement is the same {"00": 0.7, "11": 0.3} as the existing - # test_observables_no_inputs fixture, so the expectation should match 4364.36. - sub_results = [ - _build_sub_quantum_result( - sub, [[_make_exec_result(i) for i in range(sub.total_executables)]] - ) - for sub in subs - ] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - composite = merged[0] - assert composite.observables is h - assert len(composite) == 6 - assert np.isclose(composite.expectation(), 4364.36) - - -def test_from_multiple_mixed_bindings_and_failures(circuit_rx_parametrized_fixture): - """A program set with multiple bindings, split across sub-tasks, with one - executable failing in a sub-task. Failures must land at the correct original - position in the merged result.""" - c1 = circuit_rx_parametrized_fixture - c2 = Circuit().rx(0, FreeParameter("phi")) - b1 = CircuitBinding(c1, {"theta": [0.1, 0.2, 0.3]}, observables=[X(0), Y(0)]) # 6 execs - b2 = CircuitBinding(c2, {"phi": [0.4, 0.5]}) # 2 execs, no observables - ps = ProgramSet([b1, b2]) - assert ps.total_executables == 8 - - subs, mapping = ps.split(5) - # Greedy pack with max=5: b1 classes (sizes 2,2,2) fill [2+2=4, +2>5 flush], so - # sub 0 = 2 b1 classes (4 execs), sub 1 = 1 b1 class (2 execs) + b2 (2 execs) = 4 execs. - assert [s.total_executables for s in subs] == [4, 4] - - def _failure(inputs_index): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_executable_failure", - "version": "1", - }, - "inputsIndex": inputs_index, - "failureMetadata": { - "failureReason": "test failure", - "retryable": False, - "category": "DEVICE", - }, - } - - # Inject a failure at original index 5 (b1 ps=2, obs=1) which lives in sub 1. - sub_results = [] - failure_injected = False - for k, sub in enumerate(subs): - programs_execs = [] - for prog_idx, entry in enumerate(sub.entries): - num_execs = len(entry) if isinstance(entry, CircuitBinding) else 1 - execs = [] - for i in range(num_execs): - # Figure out this sub-executable's original index. Within sub k, - # j runs across all programs so we need a running counter. - j = ( - sum( - len(prev_entry) if isinstance(prev_entry, CircuitBinding) else 1 - for prev_entry in sub.entries[:prog_idx] - ) - + i - ) - if mapping[k][j] == 5: - execs.append(_failure(i)) - failure_injected = True - else: - execs.append(_make_exec_result(i)) - programs_execs.append(execs) - sub_results.append(_build_sub_quantum_result(sub, programs_execs)) - - assert failure_injected - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 2 - # Binding 0: 6 executables, position 5 is a failure. - assert len(merged[0]) == 6 - # Binding 1: 2 executables, all successful. - assert len(merged[1]) == 2 - from braket.task_result import ProgramSetExecutableFailure - - assert isinstance(merged[0].entries[5], ProgramSetExecutableFailure) - # All non-failure entries for binding 0 have the correct observables. - for i, entry in enumerate(merged[0].entries): - if isinstance(entry, MeasuredEntry): - expected_obs = b1.observables[i % len(b1.observables)] - assert entry.observable == expected_obs - # Binding 1 entries have no observable. - for entry in merged[1].entries: - if isinstance(entry, MeasuredEntry): - assert entry.observable is None - - -def test_from_multiple_validates_mapping_size(circuit_rx_parametrized_fixture): - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) - # mapping has 1 entry for 1 sub-task, but size doesn't match ps.total_executables. - with pytest.raises(ValueError, match="Index map covers 1"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0]]) - # Sub-task count doesn't match mapping's length. - with pytest.raises(ValueError, match="1 task results but 2 entries in index_map"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0], [1]]) - - -@pytest.fixture -def circuit_rx_parametrized_fixture(): - return Circuit().rx(0, FreeParameter("theta")).cnot(0, 1) - - -def test_from_multiple_with_plain_circuit_entries(): - """from_multiple should handle plain Circuit entries (no inputs, no observables).""" - c1 = ghz_test(2) - c2 = ghz_test(1) - ps = ProgramSet([c1, c2]) - subs, mapping = ps.split(1) - assert [s.total_executables for s in subs] == [1, 1] - - sub_results = [_build_sub_quantum_result(sub, [[_make_exec_result(0)]]) for sub in subs] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 2 - assert len(merged[0]) == 1 - assert len(merged[1]) == 1 - assert merged[0].observables is None - assert merged[0].entries[0].observable is None - assert merged[0].entries[0].inputs is None - - -def test_from_multiple_rejects_sub_task_over_mapping(circuit_rx_parametrized_fixture): - """Sub-task has more executables than mapping[k] covers.""" - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - # Sub-task reports 2 executables, but mapping says there's only 1. - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) - with pytest.raises(ValueError, match="produced more executables than index map"): - ProgramSetQuantumTaskResult.merge( - [sub_result], - ProgramSet(CircuitBinding(circuit_rx_parametrized_fixture, {"theta": [0.1]})), - [[0]], - ) - - -def test_from_multiple_rejects_sub_task_under_mapping(circuit_rx_parametrized_fixture): - """Sub-task has fewer executables than mapping[k] covers.""" - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - # Sub-task reports only 1 executable, but mapping says there are 2. - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0)]]) - with pytest.raises(ValueError, match="expected 2"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0, 1]]) - - -def ghz_test(n): - """Local ghz helper so tests don't depend on program_set_test_utils.""" - circuit = Circuit().h(0) - for i in range(n - 1): - circuit.cnot(i, i + 1) - return circuit From cb558790f89507d5a32bce0aeb90b102fa732401 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Sun, 10 May 2026 14:27:38 -0700 Subject: [PATCH 7/9] reformat --- src/braket/tasks/program_set_quantum_task_result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index 9f7084c97..c4c37b00e 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -260,7 +260,7 @@ def _get_inputs(program: Program, observables: Sum | list[Observable] | None) -> def _get_executable_results( executable_results: Sequence[ ProgramSetExecutableResult | ProgramSetExecutableFailure | str - ], + ], program: Program, observables: Sum | list[Observable] | None, shots_per_executable: int, @@ -300,7 +300,7 @@ def _dispatch_executable_result( program=program.source, shots=shots_per_executable, inputs={k: v[result.inputsIndex] for k, v in (program.inputs or {}).items()} - or None, + or None, observable=( observables[result.inputsIndex % len(observables)] if observables else None ), From fe2cc3cf1d7e862b35c82f24479da5fb4aaa17f9 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Sun, 10 May 2026 14:37:40 -0700 Subject: [PATCH 8/9] merging, will recreate cleanly in new PR --- .../tasks/program_set_quantum_task_result.py | 198 +++++++++- .../test_program_set_quantum_task_result.py | 366 ++++++++++++++++++ 2 files changed, 562 insertions(+), 2 deletions(-) diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index c4c37b00e..99178171e 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -16,7 +16,7 @@ import warnings from collections import Counter from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace import boto3 import numpy as np @@ -31,10 +31,15 @@ ProgramSetTaskMetadata, ProgramSetTaskResult, ) +from braket.task_result.program_set_executable_result_v1 import ( + ProgramSetExecutableResultMetadata, +) +from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata -from braket.circuits import Observable +from braket.circuits import Circuit, Observable from braket.circuits.observable import EULER_OBSERVABLE_PREFIX from braket.circuits.observables import Sum +from braket.circuits.serialization import IRType from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.tasks.measurement_utils import ( expectation_from_measurements, @@ -370,6 +375,116 @@ def from_object( program_set=program_set, ) + @staticmethod + def merge( + results: Sequence[ProgramSetQuantumTaskResult], + program_set: ProgramSet, + index_map: list[list[int]], + ) -> ProgramSetQuantumTaskResult: + """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running + each program set of ``program_set.split(...)``. + + ``index_map`` is the per-executable map returned alongside the program sets by + ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, + of the executable that the jth executable of the kth task represents. The kth task's + executables are read in order for its program set, namely across ``results[k].entries``, + and within each ``CompositeEntry`` across its ``entries``. + + The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had + been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, + and ``MeasuredEntry`` objects in the order of the program. + + Expectation values and ``Sum`` Hamiltonian expectations are computed + for the original ``ProgramSet``. + + Args: + results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same + order as ``program_set.split``'s return. + program_set (ProgramSet): The original unsplit program set. + index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. + + Returns: + ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. + + Raises: + ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` + doesn't match ``program_set.total_executables``, or if any task produces a + different number of executables than its map expects. + """ + if len(results) != len(index_map): + raise ValueError( + f"Got {len(results)} task results but {len(index_map)} entries in index_map" + ) + total_executables = program_set.total_executables + total_mapped = sum(len(m) for m in index_map) + if total_mapped != total_executables: + raise ValueError( + f"Index map covers {total_mapped} executables but the original program set " + f"has {total_executables}" + ) + + programs = [_binding_to_program(binding) for binding in program_set.entries] + executable_indices = list(program_set.enumerate_executables()) + binding_executable_counts = [_count_executables(b) for b in program_set.entries] + shots_per_executable = results[0].entries[0].shots_per_executable + + buffer = [None] * total_executables + for k, result in enumerate(results): + _buffer_result( + k=k, + result=result, + map_k=index_map[k], + program_set=program_set, + programs=programs, + executable_indices=executable_indices, + buffer=buffer, + ) + + entries = [] + start = 0 + for binding_idx, binding in enumerate(program_set.entries): + count = binding_executable_counts[binding_idx] + program = programs[binding_idx] + observables = binding.observables if isinstance(binding, CircuitBinding) else None + entries.append( + CompositeEntry( + entries=buffer[start : start + count], + program=program, + inputs=CompositeEntry._get_inputs(program, observables), + observables=observables, + shots_per_executable=shots_per_executable, + additional_metadata=None, + ) + ) + start += count + + metas = [r.task_metadata for r in results] + return ProgramSetQuantumTaskResult( + entries=entries, + task_metadata=ProgramSetTaskMetadata( + id=";".join(meta.id for meta in metas), # Better way to do this? + deviceId=metas[0].deviceId, + requestedShots=sum(m.requestedShots for m in metas), + successfulShots=sum(m.successfulShots for m in metas), + programMetadata=[ + ProgramMetadata( + executables=[ + ProgramSetExecutableResultMetadata() + for _ in range(_count_executables(b)) + ] + ) + for b in program_set.entries + ], + deviceParameters=None, # TODO: find a way to fill this in + createdAt=min(m.createdAt for m in metas if m.createdAt), + endedAt=max(m.endedAt for m in metas if m.endedAt), + status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", + totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), + ), + num_executables=total_executables, + program_set=program_set, + ) + def __len__(self): return len(self.entries) @@ -481,6 +596,85 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: return counter +def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: + if isinstance(binding, Circuit): + return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) + return binding.to_ir() + + +def _count_executables(binding: CircuitBinding | Circuit) -> int: + if isinstance(binding, Circuit): + return 1 + num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 + num_obs = len(binding.observables) if binding.observables is not None else 1 + return num_ps * num_obs + + +def _buffer_result( + k: int, + result: ProgramSetQuantumTaskResult, + map_k: list[int], + program_set: ProgramSet, + programs: list[Program], + executable_indices: list[tuple[int, int, int]], + buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], +) -> None: + j = 0 + for composite in result.entries: + for entry in composite.entries: + if j >= len(map_k): + raise ValueError( + f"t=Task {result.task_metadata.id} at index {k} " + "produced more executables than index map expects" + ) + orig_idx = map_k[j] + binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] + buffer[orig_idx] = _convert_measured_entry( + entry, + program_set.entries[binding_idx], + programs[binding_idx], + ps_idx, + obs_idx, + ) + j += 1 + if j != len(map_k): + raise ValueError( + f"Task {result.task_metadata.id} at index {k} produced {j} executables " + f"but index map expected {len(map_k)}" + ) + + +def _convert_measured_entry( + entry: MeasuredEntry | ProgramSetExecutableFailure, + original_binding: CircuitBinding | Circuit, + original_program: Program, + parameter_set_index: int, + observable_index: int, +) -> MeasuredEntry | ProgramSetExecutableFailure: + if isinstance(entry, ProgramSetExecutableFailure): + return entry + if isinstance(original_binding, Circuit): + return replace(entry, program=original_program.source, inputs=None, observable=None) + observables = original_binding.observables + if observables is None: + observable: Observable | None = None + num_obs = 1 + elif isinstance(observables, Sum): + observable = observables.summands[observable_index] + num_obs = len(observables.summands) + else: + observable = observables[observable_index] + num_obs = len(observables) + orig_inputs_index = parameter_set_index * num_obs + observable_index + program_inputs = original_program.inputs or {} + return replace( + entry, + program=original_program.source, + inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, + observable=observable, + ) + + def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: """Retrieve the S3 object body. diff --git a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py index 3580c4eb0..886a4a22e 100644 --- a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py @@ -19,6 +19,8 @@ from braket.circuits import Circuit from braket.circuits.observables import X, Y, Z +from braket.circuits.serialization import IRType +from braket.ir.openqasm import Program from braket.parametric import FreeParameter from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.schema_common import BraketSchemaBase @@ -437,3 +439,367 @@ def test_dispatch_executable_result_with_none_inputs(execution_measurement_proba assert isinstance(measured_entry, MeasuredEntry) assert measured_entry.inputs is None assert measured_entry.probabilities == {"00": 0.7, "11": 0.3} + + +_SIM_METADATA_HEADER = { + "braketSchemaHeader": {"name": "braket.task_result.simulator_metadata", "version": "1"}, + "executionDuration": 50, +} +_DEVICE_PARAMS = { + "braketSchemaHeader": { + "name": "braket.device_schema.simulators.gate_model_simulator_device_parameters", + "version": "1", + }, + "paradigmParameters": { + "braketSchemaHeader": { + "name": "braket.device_schema.gate_model_parameters", + "version": "1", + }, + "qubitCount": 5, + "disableQubitRewiring": False, + }, +} + + +def _make_exec_result(inputs_index, probs=None): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_executable_result", + "version": "1", + }, + "inputsIndex": inputs_index, + "measurementProbabilities": probs or {"00": 0.7, "11": 0.3}, + "measuredQubits": [0, 1], + } + + +def _make_program_result(program_dict, executable_dicts): + return { + "braketSchemaHeader": {"name": "braket.task_result.program_result", "version": "1"}, + "executableResults": executable_dicts, + "source": program_dict, + "additionalMetadata": {"simulatorMetadata": dict(_SIM_METADATA_HEADER)}, + } + + +def _make_task_metadata( + program_executable_counts, task_id="arn:aws:braket:::task/sub", shots_per_executable=40 +): + total = sum(program_executable_counts) + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_task_metadata", + "version": "1", + }, + "id": task_id, + "deviceId": "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + "requestedShots": shots_per_executable * total, + "successfulShots": shots_per_executable * total, + "programMetadata": [ + {"executables": [{} for _ in range(n)]} for n in program_executable_counts + ], + "deviceParameters": dict(_DEVICE_PARAMS), + "createdAt": "2024-10-15T19:06:58.986Z", + "endedAt": "2024-10-15T19:07:00.382Z", + "status": "COMPLETED", + "totalFailedExecutables": 0, + } + + +def _make_task_result(program_results, metadata): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_task_result", + "version": "1", + }, + "programResults": program_results, + "taskMetadata": metadata, + } + + +def _parse(d): + return BraketSchemaBase.parse_raw_schema(json.dumps(d)) + + +def _build_sub_quantum_result(sub_program_set, programs_execs, shots_per_executable=40): + """Build a :class:`ProgramSetQuantumTaskResult` for a sub-program-set by first + building a wire-format ``ProgramSetTaskResult`` and passing it through + :meth:`ProgramSetQuantumTaskResult.from_object`. + + Args: + sub_program_set: The sub-``ProgramSet`` whose run produced the result. + programs_execs: One list of exec-result dicts per entry in ``sub_program_set.entries``. + shots_per_executable: shots per executable, propagated to the metadata. + """ + program_results = [] + counts = [] + for entry, execs in zip(sub_program_set.entries, programs_execs, strict=True): + if isinstance(entry, CircuitBinding): + source_dict = entry.to_ir().dict() + else: + source_dict = Program(source=entry.to_ir(IRType.OPENQASM).source, inputs=None).dict() + program_results.append(_make_program_result(source_dict, execs)) + counts.append(len(execs)) + wire = _parse( + _make_task_result( + program_results, _make_task_metadata(counts, shots_per_executable=shots_per_executable) + ) + ) + return ProgramSetQuantumTaskResult.from_object(wire, sub_program_set) + + +def test_from_multiple_single_sub_task_no_split_roundtrips(circuit_rx_parametrized_fixture): + """If split returns [self], from_multiple should reproduce from_object's output.""" + binding = CircuitBinding( + circuit_rx_parametrized_fixture, + input_sets={"theta": [0.12, 2.1]}, + observables=10 * Z(0) + X(0) - 0.01 * Y(0) @ X(1), + ) + ps = ProgramSet(binding) + subs, mapping = ps.split(100) # fits, so one sub-task identical to ps. + assert subs == [ps] + + # Build a ProgramSetQuantumTaskResult that represents running this ps: the wire + # payload goes through from_object first. + sub_program = subs[0].to_ir().programs[0].dict() + execs = [_make_exec_result(i) for i in range(ps.total_executables)] + wire = _parse( + _make_task_result( + [_make_program_result(sub_program, execs)], + _make_task_metadata([ps.total_executables]), + ) + ) + reference = ProgramSetQuantumTaskResult.from_object(wire, ps) + + merged = ProgramSetQuantumTaskResult.merge([reference], ps, mapping) + + assert len(merged) == len(reference) == 1 + ref_composite = reference[0] + got_composite = merged[0] + assert len(got_composite) == len(ref_composite) + assert got_composite.program == ref_composite.program + assert got_composite.inputs == ref_composite.inputs + assert got_composite.observables == ref_composite.observables + for m_got, m_ref in zip(got_composite.entries, ref_composite.entries): + assert m_got.measured_qubits == m_ref.measured_qubits + assert m_got.probabilities == m_ref.probabilities + assert m_got.observable == m_ref.observable + assert m_got.inputs == m_ref.inputs + + +def test_from_multiple_split_list_observables(circuit_rx_parametrized_fixture): + """Split a binding with more observables than fit; scatter + regroup must + reconstruct the same CompositeEntry as running unsplit.""" + binding = CircuitBinding( + circuit_rx_parametrized_fixture, + input_sets={"theta": [0.12]}, + observables=[X(0), Y(0), Z(0), X(0) @ Y(1)], # 4 observables. + ) + ps = ProgramSet(binding) + subs, mapping = ps.split(2) # 4 > 2, so observables split into windows (0,2), (2,4). + assert [s.total_executables for s in subs] == [2, 2] + + # One sub-quantum-result per sub-program-set, built by running each through + # from_object on an inline wire payload. + sub_results = [ + _build_sub_quantum_result( + sub, [[_make_exec_result(i, {"00": 1.0}) for i in range(sub.total_executables)]] + ) + for sub in subs + ] + + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) + assert len(merged) == 1 + composite = merged[0] + # The merged composite should have 4 MeasuredEntries in canonical order, each with + # the ORIGINAL binding's observable attached at that index. + assert len(composite) == 4 + for i, measured in enumerate(composite.entries): + assert isinstance(measured, MeasuredEntry) + assert measured.observable == binding.observables[i] + assert composite.inputs == ParameterSets({"theta": [0.12]}) + # task metadata was aggregated across sub-tasks. + assert merged.num_executables == 4 + assert merged.task_metadata.requestedShots == sum( + r.task_metadata.requestedShots for r in sub_results + ) + assert merged.task_metadata.successfulShots == sum( + r.task_metadata.successfulShots for r in sub_results + ) + + +def test_from_multiple_split_sum_hamiltonian_reconstructs_expectation( + circuit_rx_parametrized_fixture, +): + """Splitting a Sum Hamiltonian across multiple sub-tasks and then merging must + reconstruct the full expectation value, because scatter+regroup feeds the original + Sum back into ``_compute_expectations``.""" + # Same fixture as existing test_observables_no_inputs (with known expectation). + circuit = Circuit().h(0).cnot(0, 1) + h = 10000 * Z(0) + 1000 * X(0) - 100 * Z(0) + 10 * Z(1) + X(1) - 0.1 * Y(1) + binding = CircuitBinding(circuit, observables=h) + ps = ProgramSet(binding) + assert ps.total_executables == 6 + + subs, mapping = ps.split(2) # 6 > 2, so Sum splits into 3 windows of size 2. + assert [s.total_executables for s in subs] == [2, 2, 2] + + # Each executable's measurement is the same {"00": 0.7, "11": 0.3} as the existing + # test_observables_no_inputs fixture, so the expectation should match 4364.36. + sub_results = [ + _build_sub_quantum_result( + sub, [[_make_exec_result(i) for i in range(sub.total_executables)]] + ) + for sub in subs + ] + + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) + composite = merged[0] + assert composite.observables is h + assert len(composite) == 6 + assert np.isclose(composite.expectation(), 4364.36) + + +def test_from_multiple_mixed_bindings_and_failures(circuit_rx_parametrized_fixture): + """A program set with multiple bindings, split across sub-tasks, with one + executable failing in a sub-task. Failures must land at the correct original + position in the merged result.""" + c1 = circuit_rx_parametrized_fixture + c2 = Circuit().rx(0, FreeParameter("phi")) + b1 = CircuitBinding(c1, {"theta": [0.1, 0.2, 0.3]}, observables=[X(0), Y(0)]) # 6 execs + b2 = CircuitBinding(c2, {"phi": [0.4, 0.5]}) # 2 execs, no observables + ps = ProgramSet([b1, b2]) + assert ps.total_executables == 8 + + subs, mapping = ps.split(5) + # Greedy pack with max=5: b1 classes (sizes 2,2,2) fill [2+2=4, +2>5 flush], so + # sub 0 = 2 b1 classes (4 execs), sub 1 = 1 b1 class (2 execs) + b2 (2 execs) = 4 execs. + assert [s.total_executables for s in subs] == [4, 4] + + def _failure(inputs_index): + return { + "braketSchemaHeader": { + "name": "braket.task_result.program_set_executable_failure", + "version": "1", + }, + "inputsIndex": inputs_index, + "failureMetadata": { + "failureReason": "test failure", + "retryable": False, + "category": "DEVICE", + }, + } + + # Inject a failure at original index 5 (b1 ps=2, obs=1) which lives in sub 1. + sub_results = [] + failure_injected = False + for k, sub in enumerate(subs): + programs_execs = [] + for prog_idx, entry in enumerate(sub.entries): + num_execs = len(entry) if isinstance(entry, CircuitBinding) else 1 + execs = [] + for i in range(num_execs): + # Figure out this sub-executable's original index. Within sub k, + # j runs across all programs so we need a running counter. + j = ( + sum( + len(prev_entry) if isinstance(prev_entry, CircuitBinding) else 1 + for prev_entry in sub.entries[:prog_idx] + ) + + i + ) + if mapping[k][j] == 5: + execs.append(_failure(i)) + failure_injected = True + else: + execs.append(_make_exec_result(i)) + programs_execs.append(execs) + sub_results.append(_build_sub_quantum_result(sub, programs_execs)) + + assert failure_injected + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) + assert len(merged) == 2 + # Binding 0: 6 executables, position 5 is a failure. + assert len(merged[0]) == 6 + # Binding 1: 2 executables, all successful. + assert len(merged[1]) == 2 + from braket.task_result import ProgramSetExecutableFailure + + assert isinstance(merged[0].entries[5], ProgramSetExecutableFailure) + # All non-failure entries for binding 0 have the correct observables. + for i, entry in enumerate(merged[0].entries): + if isinstance(entry, MeasuredEntry): + expected_obs = b1.observables[i % len(b1.observables)] + assert entry.observable == expected_obs + # Binding 1 entries have no observable. + for entry in merged[1].entries: + if isinstance(entry, MeasuredEntry): + assert entry.observable is None + + +def test_from_multiple_validates_mapping_size(circuit_rx_parametrized_fixture): + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) + # mapping has 1 entry for 1 sub-task, but size doesn't match ps.total_executables. + with pytest.raises(ValueError, match="Index map covers 1"): + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0]]) + # Sub-task count doesn't match mapping's length. + with pytest.raises(ValueError, match="1 task results but 2 entries in index_map"): + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0], [1]]) + + +@pytest.fixture +def circuit_rx_parametrized_fixture(): + return Circuit().rx(0, FreeParameter("theta")).cnot(0, 1) + + +def test_from_multiple_with_plain_circuit_entries(): + """from_multiple should handle plain Circuit entries (no inputs, no observables).""" + c1 = ghz_test(2) + c2 = ghz_test(1) + ps = ProgramSet([c1, c2]) + subs, mapping = ps.split(1) + assert [s.total_executables for s in subs] == [1, 1] + + sub_results = [_build_sub_quantum_result(sub, [[_make_exec_result(0)]]) for sub in subs] + + merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) + assert len(merged) == 2 + assert len(merged[0]) == 1 + assert len(merged[1]) == 1 + assert merged[0].observables is None + assert merged[0].entries[0].observable is None + assert merged[0].entries[0].inputs is None + + +def test_from_multiple_rejects_sub_task_over_mapping(circuit_rx_parametrized_fixture): + """Sub-task has more executables than mapping[k] covers.""" + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + # Sub-task reports 2 executables, but mapping says there's only 1. + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) + with pytest.raises(ValueError, match="produced more executables than index map"): + ProgramSetQuantumTaskResult.merge( + [sub_result], + ProgramSet(CircuitBinding(circuit_rx_parametrized_fixture, {"theta": [0.1]})), + [[0]], + ) + + +def test_from_multiple_rejects_sub_task_under_mapping(circuit_rx_parametrized_fixture): + """Sub-task has fewer executables than mapping[k] covers.""" + binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) + ps = ProgramSet(binding) + # Sub-task reports only 1 executable, but mapping says there are 2. + sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0)]]) + with pytest.raises(ValueError, match="expected 2"): + ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0, 1]]) + + +def ghz_test(n): + """Local ghz helper so tests don't depend on program_set_test_utils.""" + circuit = Circuit().h(0) + for i in range(n - 1): + circuit.cnot(i, i + 1) + return circuit From 2774cdffe93035c9611b79d3ad550047d8020cae Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Sun, 10 May 2026 14:41:31 -0700 Subject: [PATCH 9/9] Remove results merging --- .../tasks/program_set_quantum_task_result.py | 198 +--------- .../test_program_set_quantum_task_result.py | 366 ------------------ 2 files changed, 2 insertions(+), 562 deletions(-) diff --git a/src/braket/tasks/program_set_quantum_task_result.py b/src/braket/tasks/program_set_quantum_task_result.py index 99178171e..c4c37b00e 100644 --- a/src/braket/tasks/program_set_quantum_task_result.py +++ b/src/braket/tasks/program_set_quantum_task_result.py @@ -16,7 +16,7 @@ import warnings from collections import Counter from collections.abc import Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass import boto3 import numpy as np @@ -31,15 +31,10 @@ ProgramSetTaskMetadata, ProgramSetTaskResult, ) -from braket.task_result.program_set_executable_result_v1 import ( - ProgramSetExecutableResultMetadata, -) -from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata -from braket.circuits import Circuit, Observable +from braket.circuits import Observable from braket.circuits.observable import EULER_OBSERVABLE_PREFIX from braket.circuits.observables import Sum -from braket.circuits.serialization import IRType from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.tasks.measurement_utils import ( expectation_from_measurements, @@ -375,116 +370,6 @@ def from_object( program_set=program_set, ) - @staticmethod - def merge( - results: Sequence[ProgramSetQuantumTaskResult], - program_set: ProgramSet, - index_map: list[list[int]], - ) -> ProgramSetQuantumTaskResult: - """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running - each program set of ``program_set.split(...)``. - - ``index_map`` is the per-executable map returned alongside the program sets by - ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, - of the executable that the jth executable of the kth task represents. The kth task's - executables are read in order for its program set, namely across ``results[k].entries``, - and within each ``CompositeEntry`` across its ``entries``. - - The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had - been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, - and ``MeasuredEntry`` objects in the order of the program. - - Expectation values and ``Sum`` Hamiltonian expectations are computed - for the original ``ProgramSet``. - - Args: - results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same - order as ``program_set.split``'s return. - program_set (ProgramSet): The original unsplit program set. - index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. - - Returns: - ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. - - Raises: - ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` - doesn't match ``program_set.total_executables``, or if any task produces a - different number of executables than its map expects. - """ - if len(results) != len(index_map): - raise ValueError( - f"Got {len(results)} task results but {len(index_map)} entries in index_map" - ) - total_executables = program_set.total_executables - total_mapped = sum(len(m) for m in index_map) - if total_mapped != total_executables: - raise ValueError( - f"Index map covers {total_mapped} executables but the original program set " - f"has {total_executables}" - ) - - programs = [_binding_to_program(binding) for binding in program_set.entries] - executable_indices = list(program_set.enumerate_executables()) - binding_executable_counts = [_count_executables(b) for b in program_set.entries] - shots_per_executable = results[0].entries[0].shots_per_executable - - buffer = [None] * total_executables - for k, result in enumerate(results): - _buffer_result( - k=k, - result=result, - map_k=index_map[k], - program_set=program_set, - programs=programs, - executable_indices=executable_indices, - buffer=buffer, - ) - - entries = [] - start = 0 - for binding_idx, binding in enumerate(program_set.entries): - count = binding_executable_counts[binding_idx] - program = programs[binding_idx] - observables = binding.observables if isinstance(binding, CircuitBinding) else None - entries.append( - CompositeEntry( - entries=buffer[start : start + count], - program=program, - inputs=CompositeEntry._get_inputs(program, observables), - observables=observables, - shots_per_executable=shots_per_executable, - additional_metadata=None, - ) - ) - start += count - - metas = [r.task_metadata for r in results] - return ProgramSetQuantumTaskResult( - entries=entries, - task_metadata=ProgramSetTaskMetadata( - id=";".join(meta.id for meta in metas), # Better way to do this? - deviceId=metas[0].deviceId, - requestedShots=sum(m.requestedShots for m in metas), - successfulShots=sum(m.successfulShots for m in metas), - programMetadata=[ - ProgramMetadata( - executables=[ - ProgramSetExecutableResultMetadata() - for _ in range(_count_executables(b)) - ] - ) - for b in program_set.entries - ], - deviceParameters=None, # TODO: find a way to fill this in - createdAt=min(m.createdAt for m in metas if m.createdAt), - endedAt=max(m.endedAt for m in metas if m.endedAt), - status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", - totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), - ), - num_executables=total_executables, - program_set=program_set, - ) - def __len__(self): return len(self.entries) @@ -596,85 +481,6 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: return counter -def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: - if isinstance(binding, Circuit): - return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) - return binding.to_ir() - - -def _count_executables(binding: CircuitBinding | Circuit) -> int: - if isinstance(binding, Circuit): - return 1 - num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 - num_obs = len(binding.observables) if binding.observables is not None else 1 - return num_ps * num_obs - - -def _buffer_result( - k: int, - result: ProgramSetQuantumTaskResult, - map_k: list[int], - program_set: ProgramSet, - programs: list[Program], - executable_indices: list[tuple[int, int, int]], - buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], -) -> None: - j = 0 - for composite in result.entries: - for entry in composite.entries: - if j >= len(map_k): - raise ValueError( - f"t=Task {result.task_metadata.id} at index {k} " - "produced more executables than index map expects" - ) - orig_idx = map_k[j] - binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] - buffer[orig_idx] = _convert_measured_entry( - entry, - program_set.entries[binding_idx], - programs[binding_idx], - ps_idx, - obs_idx, - ) - j += 1 - if j != len(map_k): - raise ValueError( - f"Task {result.task_metadata.id} at index {k} produced {j} executables " - f"but index map expected {len(map_k)}" - ) - - -def _convert_measured_entry( - entry: MeasuredEntry | ProgramSetExecutableFailure, - original_binding: CircuitBinding | Circuit, - original_program: Program, - parameter_set_index: int, - observable_index: int, -) -> MeasuredEntry | ProgramSetExecutableFailure: - if isinstance(entry, ProgramSetExecutableFailure): - return entry - if isinstance(original_binding, Circuit): - return replace(entry, program=original_program.source, inputs=None, observable=None) - observables = original_binding.observables - if observables is None: - observable: Observable | None = None - num_obs = 1 - elif isinstance(observables, Sum): - observable = observables.summands[observable_index] - num_obs = len(observables.summands) - else: - observable = observables[observable_index] - num_obs = len(observables) - orig_inputs_index = parameter_set_index * num_obs + observable_index - program_inputs = original_program.inputs or {} - return replace( - entry, - program=original_program.source, - inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, - observable=observable, - ) - - def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: """Retrieve the S3 object body. diff --git a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py index 886a4a22e..3580c4eb0 100644 --- a/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_program_set_quantum_task_result.py @@ -19,8 +19,6 @@ from braket.circuits import Circuit from braket.circuits.observables import X, Y, Z -from braket.circuits.serialization import IRType -from braket.ir.openqasm import Program from braket.parametric import FreeParameter from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet from braket.schema_common import BraketSchemaBase @@ -439,367 +437,3 @@ def test_dispatch_executable_result_with_none_inputs(execution_measurement_proba assert isinstance(measured_entry, MeasuredEntry) assert measured_entry.inputs is None assert measured_entry.probabilities == {"00": 0.7, "11": 0.3} - - -_SIM_METADATA_HEADER = { - "braketSchemaHeader": {"name": "braket.task_result.simulator_metadata", "version": "1"}, - "executionDuration": 50, -} -_DEVICE_PARAMS = { - "braketSchemaHeader": { - "name": "braket.device_schema.simulators.gate_model_simulator_device_parameters", - "version": "1", - }, - "paradigmParameters": { - "braketSchemaHeader": { - "name": "braket.device_schema.gate_model_parameters", - "version": "1", - }, - "qubitCount": 5, - "disableQubitRewiring": False, - }, -} - - -def _make_exec_result(inputs_index, probs=None): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_executable_result", - "version": "1", - }, - "inputsIndex": inputs_index, - "measurementProbabilities": probs or {"00": 0.7, "11": 0.3}, - "measuredQubits": [0, 1], - } - - -def _make_program_result(program_dict, executable_dicts): - return { - "braketSchemaHeader": {"name": "braket.task_result.program_result", "version": "1"}, - "executableResults": executable_dicts, - "source": program_dict, - "additionalMetadata": {"simulatorMetadata": dict(_SIM_METADATA_HEADER)}, - } - - -def _make_task_metadata( - program_executable_counts, task_id="arn:aws:braket:::task/sub", shots_per_executable=40 -): - total = sum(program_executable_counts) - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_task_metadata", - "version": "1", - }, - "id": task_id, - "deviceId": "arn:aws:braket:::device/quantum-simulator/amazon/sv1", - "requestedShots": shots_per_executable * total, - "successfulShots": shots_per_executable * total, - "programMetadata": [ - {"executables": [{} for _ in range(n)]} for n in program_executable_counts - ], - "deviceParameters": dict(_DEVICE_PARAMS), - "createdAt": "2024-10-15T19:06:58.986Z", - "endedAt": "2024-10-15T19:07:00.382Z", - "status": "COMPLETED", - "totalFailedExecutables": 0, - } - - -def _make_task_result(program_results, metadata): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_task_result", - "version": "1", - }, - "programResults": program_results, - "taskMetadata": metadata, - } - - -def _parse(d): - return BraketSchemaBase.parse_raw_schema(json.dumps(d)) - - -def _build_sub_quantum_result(sub_program_set, programs_execs, shots_per_executable=40): - """Build a :class:`ProgramSetQuantumTaskResult` for a sub-program-set by first - building a wire-format ``ProgramSetTaskResult`` and passing it through - :meth:`ProgramSetQuantumTaskResult.from_object`. - - Args: - sub_program_set: The sub-``ProgramSet`` whose run produced the result. - programs_execs: One list of exec-result dicts per entry in ``sub_program_set.entries``. - shots_per_executable: shots per executable, propagated to the metadata. - """ - program_results = [] - counts = [] - for entry, execs in zip(sub_program_set.entries, programs_execs, strict=True): - if isinstance(entry, CircuitBinding): - source_dict = entry.to_ir().dict() - else: - source_dict = Program(source=entry.to_ir(IRType.OPENQASM).source, inputs=None).dict() - program_results.append(_make_program_result(source_dict, execs)) - counts.append(len(execs)) - wire = _parse( - _make_task_result( - program_results, _make_task_metadata(counts, shots_per_executable=shots_per_executable) - ) - ) - return ProgramSetQuantumTaskResult.from_object(wire, sub_program_set) - - -def test_from_multiple_single_sub_task_no_split_roundtrips(circuit_rx_parametrized_fixture): - """If split returns [self], from_multiple should reproduce from_object's output.""" - binding = CircuitBinding( - circuit_rx_parametrized_fixture, - input_sets={"theta": [0.12, 2.1]}, - observables=10 * Z(0) + X(0) - 0.01 * Y(0) @ X(1), - ) - ps = ProgramSet(binding) - subs, mapping = ps.split(100) # fits, so one sub-task identical to ps. - assert subs == [ps] - - # Build a ProgramSetQuantumTaskResult that represents running this ps: the wire - # payload goes through from_object first. - sub_program = subs[0].to_ir().programs[0].dict() - execs = [_make_exec_result(i) for i in range(ps.total_executables)] - wire = _parse( - _make_task_result( - [_make_program_result(sub_program, execs)], - _make_task_metadata([ps.total_executables]), - ) - ) - reference = ProgramSetQuantumTaskResult.from_object(wire, ps) - - merged = ProgramSetQuantumTaskResult.merge([reference], ps, mapping) - - assert len(merged) == len(reference) == 1 - ref_composite = reference[0] - got_composite = merged[0] - assert len(got_composite) == len(ref_composite) - assert got_composite.program == ref_composite.program - assert got_composite.inputs == ref_composite.inputs - assert got_composite.observables == ref_composite.observables - for m_got, m_ref in zip(got_composite.entries, ref_composite.entries): - assert m_got.measured_qubits == m_ref.measured_qubits - assert m_got.probabilities == m_ref.probabilities - assert m_got.observable == m_ref.observable - assert m_got.inputs == m_ref.inputs - - -def test_from_multiple_split_list_observables(circuit_rx_parametrized_fixture): - """Split a binding with more observables than fit; scatter + regroup must - reconstruct the same CompositeEntry as running unsplit.""" - binding = CircuitBinding( - circuit_rx_parametrized_fixture, - input_sets={"theta": [0.12]}, - observables=[X(0), Y(0), Z(0), X(0) @ Y(1)], # 4 observables. - ) - ps = ProgramSet(binding) - subs, mapping = ps.split(2) # 4 > 2, so observables split into windows (0,2), (2,4). - assert [s.total_executables for s in subs] == [2, 2] - - # One sub-quantum-result per sub-program-set, built by running each through - # from_object on an inline wire payload. - sub_results = [ - _build_sub_quantum_result( - sub, [[_make_exec_result(i, {"00": 1.0}) for i in range(sub.total_executables)]] - ) - for sub in subs - ] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 1 - composite = merged[0] - # The merged composite should have 4 MeasuredEntries in canonical order, each with - # the ORIGINAL binding's observable attached at that index. - assert len(composite) == 4 - for i, measured in enumerate(composite.entries): - assert isinstance(measured, MeasuredEntry) - assert measured.observable == binding.observables[i] - assert composite.inputs == ParameterSets({"theta": [0.12]}) - # task metadata was aggregated across sub-tasks. - assert merged.num_executables == 4 - assert merged.task_metadata.requestedShots == sum( - r.task_metadata.requestedShots for r in sub_results - ) - assert merged.task_metadata.successfulShots == sum( - r.task_metadata.successfulShots for r in sub_results - ) - - -def test_from_multiple_split_sum_hamiltonian_reconstructs_expectation( - circuit_rx_parametrized_fixture, -): - """Splitting a Sum Hamiltonian across multiple sub-tasks and then merging must - reconstruct the full expectation value, because scatter+regroup feeds the original - Sum back into ``_compute_expectations``.""" - # Same fixture as existing test_observables_no_inputs (with known expectation). - circuit = Circuit().h(0).cnot(0, 1) - h = 10000 * Z(0) + 1000 * X(0) - 100 * Z(0) + 10 * Z(1) + X(1) - 0.1 * Y(1) - binding = CircuitBinding(circuit, observables=h) - ps = ProgramSet(binding) - assert ps.total_executables == 6 - - subs, mapping = ps.split(2) # 6 > 2, so Sum splits into 3 windows of size 2. - assert [s.total_executables for s in subs] == [2, 2, 2] - - # Each executable's measurement is the same {"00": 0.7, "11": 0.3} as the existing - # test_observables_no_inputs fixture, so the expectation should match 4364.36. - sub_results = [ - _build_sub_quantum_result( - sub, [[_make_exec_result(i) for i in range(sub.total_executables)]] - ) - for sub in subs - ] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - composite = merged[0] - assert composite.observables is h - assert len(composite) == 6 - assert np.isclose(composite.expectation(), 4364.36) - - -def test_from_multiple_mixed_bindings_and_failures(circuit_rx_parametrized_fixture): - """A program set with multiple bindings, split across sub-tasks, with one - executable failing in a sub-task. Failures must land at the correct original - position in the merged result.""" - c1 = circuit_rx_parametrized_fixture - c2 = Circuit().rx(0, FreeParameter("phi")) - b1 = CircuitBinding(c1, {"theta": [0.1, 0.2, 0.3]}, observables=[X(0), Y(0)]) # 6 execs - b2 = CircuitBinding(c2, {"phi": [0.4, 0.5]}) # 2 execs, no observables - ps = ProgramSet([b1, b2]) - assert ps.total_executables == 8 - - subs, mapping = ps.split(5) - # Greedy pack with max=5: b1 classes (sizes 2,2,2) fill [2+2=4, +2>5 flush], so - # sub 0 = 2 b1 classes (4 execs), sub 1 = 1 b1 class (2 execs) + b2 (2 execs) = 4 execs. - assert [s.total_executables for s in subs] == [4, 4] - - def _failure(inputs_index): - return { - "braketSchemaHeader": { - "name": "braket.task_result.program_set_executable_failure", - "version": "1", - }, - "inputsIndex": inputs_index, - "failureMetadata": { - "failureReason": "test failure", - "retryable": False, - "category": "DEVICE", - }, - } - - # Inject a failure at original index 5 (b1 ps=2, obs=1) which lives in sub 1. - sub_results = [] - failure_injected = False - for k, sub in enumerate(subs): - programs_execs = [] - for prog_idx, entry in enumerate(sub.entries): - num_execs = len(entry) if isinstance(entry, CircuitBinding) else 1 - execs = [] - for i in range(num_execs): - # Figure out this sub-executable's original index. Within sub k, - # j runs across all programs so we need a running counter. - j = ( - sum( - len(prev_entry) if isinstance(prev_entry, CircuitBinding) else 1 - for prev_entry in sub.entries[:prog_idx] - ) - + i - ) - if mapping[k][j] == 5: - execs.append(_failure(i)) - failure_injected = True - else: - execs.append(_make_exec_result(i)) - programs_execs.append(execs) - sub_results.append(_build_sub_quantum_result(sub, programs_execs)) - - assert failure_injected - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 2 - # Binding 0: 6 executables, position 5 is a failure. - assert len(merged[0]) == 6 - # Binding 1: 2 executables, all successful. - assert len(merged[1]) == 2 - from braket.task_result import ProgramSetExecutableFailure - - assert isinstance(merged[0].entries[5], ProgramSetExecutableFailure) - # All non-failure entries for binding 0 have the correct observables. - for i, entry in enumerate(merged[0].entries): - if isinstance(entry, MeasuredEntry): - expected_obs = b1.observables[i % len(b1.observables)] - assert entry.observable == expected_obs - # Binding 1 entries have no observable. - for entry in merged[1].entries: - if isinstance(entry, MeasuredEntry): - assert entry.observable is None - - -def test_from_multiple_validates_mapping_size(circuit_rx_parametrized_fixture): - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) - # mapping has 1 entry for 1 sub-task, but size doesn't match ps.total_executables. - with pytest.raises(ValueError, match="Index map covers 1"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0]]) - # Sub-task count doesn't match mapping's length. - with pytest.raises(ValueError, match="1 task results but 2 entries in index_map"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0], [1]]) - - -@pytest.fixture -def circuit_rx_parametrized_fixture(): - return Circuit().rx(0, FreeParameter("theta")).cnot(0, 1) - - -def test_from_multiple_with_plain_circuit_entries(): - """from_multiple should handle plain Circuit entries (no inputs, no observables).""" - c1 = ghz_test(2) - c2 = ghz_test(1) - ps = ProgramSet([c1, c2]) - subs, mapping = ps.split(1) - assert [s.total_executables for s in subs] == [1, 1] - - sub_results = [_build_sub_quantum_result(sub, [[_make_exec_result(0)]]) for sub in subs] - - merged = ProgramSetQuantumTaskResult.merge(sub_results, ps, mapping) - assert len(merged) == 2 - assert len(merged[0]) == 1 - assert len(merged[1]) == 1 - assert merged[0].observables is None - assert merged[0].entries[0].observable is None - assert merged[0].entries[0].inputs is None - - -def test_from_multiple_rejects_sub_task_over_mapping(circuit_rx_parametrized_fixture): - """Sub-task has more executables than mapping[k] covers.""" - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - # Sub-task reports 2 executables, but mapping says there's only 1. - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0), _make_exec_result(1)]]) - with pytest.raises(ValueError, match="produced more executables than index map"): - ProgramSetQuantumTaskResult.merge( - [sub_result], - ProgramSet(CircuitBinding(circuit_rx_parametrized_fixture, {"theta": [0.1]})), - [[0]], - ) - - -def test_from_multiple_rejects_sub_task_under_mapping(circuit_rx_parametrized_fixture): - """Sub-task has fewer executables than mapping[k] covers.""" - binding = CircuitBinding(circuit_rx_parametrized_fixture, input_sets={"theta": [0.1, 0.2]}) - ps = ProgramSet(binding) - # Sub-task reports only 1 executable, but mapping says there are 2. - sub_result = _build_sub_quantum_result(ps, [[_make_exec_result(0)]]) - with pytest.raises(ValueError, match="expected 2"): - ProgramSetQuantumTaskResult.merge([sub_result], ps, [[0, 1]]) - - -def ghz_test(n): - """Local ghz helper so tests don't depend on program_set_test_utils.""" - circuit = Circuit().h(0) - for i in range(n - 1): - circuit.cnot(i, i + 1) - return circuit