Skip to content

Commit df90a6e

Browse files
committed
feat: enhance CUDA device handling and update sample vector documentation
1 parent be21049 commit df90a6e

3 files changed

Lines changed: 33 additions & 10 deletions

File tree

qdp/qdp-python/qumat_qdp/api.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,15 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
188188

189189
from qumat_qdp.torch_ref import encode
190190

191-
device = f"cuda:{self._device_id}" if torch.cuda.is_available() else "cpu"
191+
if torch.cuda.is_available():
192+
if self._device_id < 0 or self._device_id >= torch.cuda.device_count():
193+
raise ValueError(
194+
f"Invalid CUDA device_id {self._device_id}; "
195+
f"{torch.cuda.device_count()} device(s) available."
196+
)
197+
device = f"cuda:{self._device_id}"
198+
else:
199+
device = "cpu"
192200
# _validate() guarantees these are not None.
193201
assert self._num_qubits is not None
194202
assert self._total_batches is not None
@@ -205,9 +213,11 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
205213
else:
206214
sample_dim = 1 << num_qubits
207215

208-
# Generate all batch data upfront.
209-
batches = []
210-
for b in range(self._total_batches + self._warmup_batches):
216+
# Pre-generate a small pool of batch tensors and cycle through them
217+
# to keep memory bounded at high qubit counts while still varying data.
218+
pool_size = min(8, self._total_batches + self._warmup_batches)
219+
pool: list[torch.Tensor] = []
220+
for _ in range(pool_size):
211221
if encoding_method == "basis":
212222
data = torch.randint(
213223
0, 1 << num_qubits, (batch_size,), device=device
@@ -216,18 +226,18 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
216226
data = torch.randn(
217227
batch_size, sample_dim, dtype=torch.float64, device=device
218228
)
219-
batches.append(data)
229+
pool.append(data)
220230

221231
# Warmup.
222232
for b in range(self._warmup_batches):
223-
encode(batches[b], num_qubits, encoding_method, device=device)
233+
encode(pool[b % pool_size], num_qubits, encoding_method, device=device)
224234
if device.startswith("cuda"):
225235
torch.cuda.synchronize()
226236

227237
# Timed run.
228238
start = time.perf_counter()
229-
for b in range(self._warmup_batches, len(batches)):
230-
encode(batches[b], num_qubits, encoding_method, device=device)
239+
for b in range(self._total_batches):
240+
encode(pool[b % pool_size], num_qubits, encoding_method, device=device)
231241
if device.startswith("cuda"):
232242
torch.cuda.synchronize()
233243
duration = time.perf_counter() - start

qdp/qdp-python/qumat_qdp/loader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ def _validate_loader_args(
8383

8484

8585
def _build_sample(seed: int, vector_len: int, encoding_method: str) -> list[float]:
86-
"""Build a single deterministic sample vector (mirrors benchmark/utils.py:build_sample)."""
86+
"""Build a single deterministic sample vector for the given encoding method.
87+
88+
Supports amplitude, angle, basis, and iqp (iqp uses the same mask-and-scale
89+
logic as amplitude).
90+
"""
8791
import numpy as np
8892

8993
if encoding_method == "basis":
@@ -337,7 +341,15 @@ def _create_pytorch_iterator(self, use_synthetic: bool) -> Iterator[object]:
337341

338342
from qumat_qdp.torch_ref import encode
339343

340-
device = f"cuda:{self._device_id}" if torch.cuda.is_available() else "cpu"
344+
if torch.cuda.is_available():
345+
if self._device_id < 0 or self._device_id >= torch.cuda.device_count():
346+
raise ValueError(
347+
f"Invalid CUDA device_id {self._device_id}; "
348+
f"{torch.cuda.device_count()} device(s) available."
349+
)
350+
device = f"cuda:{self._device_id}"
351+
else:
352+
device = "cpu"
341353

342354
if use_synthetic:
343355
return self._pytorch_synthetic_iter(torch, encode, device)

testing/qdp_python/test_torch_ref.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def _require_qdp(self):
365365
pytest.importorskip("_qdp")
366366

367367
@pytest.mark.gpu
368+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
368369
@pytest.mark.parametrize("encoding", ["amplitude", "angle", "basis", "iqp"])
369370
def test_encoding_matches_rust(self, encoding):
370371
import _qdp

0 commit comments

Comments (0)