@@ -188,7 +188,15 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
188 188
189 189         from qumat_qdp.torch_ref import encode
190 190
191     -        device = f"cuda:{self._device_id}" if torch.cuda.is_available() else "cpu"
    191 +        if torch.cuda.is_available():
    192 +            if self._device_id < 0 or self._device_id >= torch.cuda.device_count():
    193 +                raise ValueError(
    194 +                    f"Invalid CUDA device_id {self._device_id}; "
    195 +                    f"{torch.cuda.device_count()} device(s) available."
    196 +                )
    197 +            device = f"cuda:{self._device_id}"
    198 +        else:
    199 +            device = "cpu"
192 200         # _validate() guarantees these are not None.
193 201         assert self._num_qubits is not None
194 202         assert self._total_batches is not None
@@ -205,9 +213,11 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
205 213         else:
206 214             sample_dim = 1 << num_qubits
207 215
208     -        # Generate all batch data upfront.
209     -        batches = []
210     -        for b in range(self._total_batches + self._warmup_batches):
    216 +        # Pre-generate a small pool of batch tensors and cycle through them
    217 +        # to keep memory bounded at high qubit counts while still varying data.
    218 +        pool_size = min(8, self._total_batches + self._warmup_batches)
    219 +        pool: list[torch.Tensor] = []
    220 +        for _ in range(pool_size):
211 221             if encoding_method == "basis":
212 222                 data = torch.randint(
213 223                     0, 1 << num_qubits, (batch_size,), device=device
@@ -216,18 +226,18 @@ def _run_throughput_pytorch(self) -> ThroughputResult:
216 226                 data = torch.randn(
217 227                     batch_size, sample_dim, dtype=torch.float64, device=device
218 228                 )
219     -            batches.append(data)
    229 +            pool.append(data)
220 230
221 231         # Warmup.
222 232         for b in range(self._warmup_batches):
223     -            encode(batches[b], num_qubits, encoding_method, device=device)
    233 +            encode(pool[b % pool_size], num_qubits, encoding_method, device=device)
224 234         if device.startswith("cuda"):
225 235             torch.cuda.synchronize()
226 236
227 237         # Timed run.
228 238         start = time.perf_counter()
229     -        for b in range(self._warmup_batches, len(batches)):
230     -            encode(batches[b], num_qubits, encoding_method, device=device)
    239 +        for b in range(self._total_batches):
    240 +            encode(pool[b % pool_size], num_qubits, encoding_method, device=device)
231 241         if device.startswith("cuda"):
232 242             torch.cuda.synchronize()
233 243         duration = time.perf_counter() - start
0 commit comments