@@ -457,6 +457,216 @@ impl QuantumEncoder for AmplitudeEncoder {
457457 Ok ( batch_state_vector)
458458 }
459459
460+ /// Encode multiple samples in a single GPU allocation and kernel launch for f32 inputs
461+ #[ cfg( target_os = "linux" ) ]
462+ fn encode_batch_f32 (
463+ & self ,
464+ device : & Arc < CudaDevice > ,
465+ batch_data : & [ f32 ] ,
466+ num_samples : usize ,
467+ sample_size : usize ,
468+ num_qubits : usize ,
469+ ) -> Result < GpuStateVector > {
470+ crate :: profile_scope!( "AmplitudeEncoder::encode_batch_f32" ) ;
471+
472+ // Validate inputs. Wait, Preprocessor::validate_batch currently takes f64...
473+ // We will just do a basic length check if f32 validation is missing.
474+ let state_len = 1 << num_qubits;
475+ if batch_data. len ( ) != num_samples * sample_size {
476+ return Err ( MahoutError :: InvalidInput ( "batch_data length mismatch" . into ( ) ) ) ;
477+ }
478+
479+ let batch_state_vector = {
480+ crate :: profile_scope!( "GPU::AllocBatch_f32" ) ;
481+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float32 ) ?
482+ } ;
483+
484+ // Upload input data to GPU
485+ let input_batch_gpu = {
486+ crate :: profile_scope!( "GPU::H2D_InputBatch_f32" ) ;
487+ device. htod_sync_copy ( batch_data) . map_err ( |e| {
488+ MahoutError :: MemoryAllocation ( format ! ( "Failed to upload batch input: {:?}" , e) )
489+ } ) ?
490+ } ;
491+
492+ // Compute inverse norms on GPU using warp-reduced kernel
493+ let inv_norms_gpu = {
494+ crate :: profile_scope!( "GPU::BatchNormKernel_f32" ) ;
495+ use cudarc:: driver:: DevicePtrMut ;
496+ let mut buffer = device. alloc_zeros :: < f32 > ( num_samples) . map_err ( |e| {
497+ MahoutError :: MemoryAllocation ( format ! ( "Failed to allocate norm buffer: {:?}" , e) )
498+ } ) ?;
499+
500+ let ret = unsafe {
501+ launch_l2_norm_batch_f32 (
502+ * input_batch_gpu. device_ptr ( ) as * const f32 ,
503+ num_samples,
504+ sample_size,
505+ * buffer. device_ptr_mut ( ) as * mut f32 ,
506+ std:: ptr:: null_mut ( ) , // default stream
507+ )
508+ } ;
509+
510+ if ret != 0 {
511+ return Err ( MahoutError :: KernelLaunch ( format ! (
512+ "Norm reduction kernel failed: {} ({})" ,
513+ ret,
514+ cuda_error_to_string( ret)
515+ ) ) ) ;
516+ }
517+ buffer
518+ } ;
519+
520+ // Validate norms on host
521+ {
522+ crate :: profile_scope!( "GPU::NormValidation_f32" ) ;
523+ let host_inv_norms = device
524+ . dtoh_sync_copy ( & inv_norms_gpu)
525+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
526+
527+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
528+ return Err ( MahoutError :: InvalidInput (
529+ "One or more samples have zero or invalid norm" . to_string ( ) ,
530+ ) ) ;
531+ }
532+ }
533+
534+ // Launch batch kernel
535+ {
536+ crate :: profile_scope!( "GPU::BatchKernelLaunch_f32" ) ;
537+ use cudarc:: driver:: DevicePtr ;
538+ let state_ptr = batch_state_vector. ptr_f32 ( ) . ok_or_else ( || {
539+ MahoutError :: InvalidInput (
540+ "Batch state vector precision mismatch (expected float32 buffer)" . to_string ( ) ,
541+ )
542+ } ) ?;
543+ let ret = unsafe {
544+ launch_amplitude_encode_batch_f32 (
545+ * input_batch_gpu. device_ptr ( ) as * const f32 ,
546+ state_ptr as * mut c_void ,
547+ * inv_norms_gpu. device_ptr ( ) as * const f32 ,
548+ num_samples,
549+ sample_size,
550+ state_len,
551+ std:: ptr:: null_mut ( ) , // default stream
552+ )
553+ } ;
554+
555+ if ret != 0 {
556+ return Err ( MahoutError :: KernelLaunch ( format ! (
557+ "Batch kernel launch failed: {} ({})" ,
558+ ret,
559+ cuda_error_to_string( ret)
560+ ) ) ) ;
561+ }
562+ }
563+
564+ {
565+ crate :: profile_scope!( "GPU::Synchronize" ) ;
566+ device
567+ . synchronize ( )
568+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Sync failed: {:?}" , e) ) ) ?;
569+ }
570+
571+ Ok ( batch_state_vector)
572+ }
573+
574+ #[ cfg( target_os = "linux" ) ]
575+ unsafe fn encode_batch_from_gpu_ptr_f32 (
576+ & self ,
577+ device : & Arc < CudaDevice > ,
578+ input_batch_d : * const c_void ,
579+ num_samples : usize ,
580+ sample_size : usize ,
581+ num_qubits : usize ,
582+ stream : * mut c_void ,
583+ ) -> Result < GpuStateVector > {
584+ let state_len = 1 << num_qubits;
585+ if sample_size == 0 {
586+ return Err ( MahoutError :: InvalidInput (
587+ "Sample size cannot be zero" . into ( ) ,
588+ ) ) ;
589+ }
590+ if sample_size > state_len {
591+ return Err ( MahoutError :: InvalidInput ( format ! (
592+ "Sample size {} exceeds state vector size {} (2^{} qubits)" ,
593+ sample_size, state_len, num_qubits
594+ ) ) ) ;
595+ }
596+ let input_batch_d = input_batch_d as * const f32 ;
597+ let batch_state_vector = {
598+ crate :: profile_scope!( "GPU::AllocBatch_f32" ) ;
599+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float32 ) ?
600+ } ;
601+ let inv_norms_gpu = {
602+ crate :: profile_scope!( "GPU::BatchNormKernel_f32" ) ;
603+ use cudarc:: driver:: DevicePtrMut ;
604+ let mut buffer = device. alloc_zeros :: < f32 > ( num_samples) . map_err ( |e| {
605+ MahoutError :: MemoryAllocation ( format ! ( "Failed to allocate norm buffer: {:?}" , e) )
606+ } ) ?;
607+ let ret = unsafe {
608+ launch_l2_norm_batch_f32 (
609+ input_batch_d,
610+ num_samples,
611+ sample_size,
612+ * buffer. device_ptr_mut ( ) as * mut f32 ,
613+ stream,
614+ )
615+ } ;
616+ if ret != 0 {
617+ return Err ( MahoutError :: KernelLaunch ( format ! (
618+ "Norm reduction kernel failed with CUDA error code: {} ({})" ,
619+ ret,
620+ cuda_error_to_string( ret)
621+ ) ) ) ;
622+ }
623+ buffer
624+ } ;
625+ {
626+ crate :: profile_scope!( "GPU::NormValidation_f32" ) ;
627+ let host_inv_norms = device
628+ . dtoh_sync_copy ( & inv_norms_gpu)
629+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
630+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
631+ return Err ( MahoutError :: InvalidInput (
632+ "One or more samples have zero or invalid norm" . to_string ( ) ,
633+ ) ) ;
634+ }
635+ }
636+ {
637+ crate :: profile_scope!( "GPU::BatchKernelLaunch_f32" ) ;
638+ use cudarc:: driver:: DevicePtr ;
639+ let state_ptr = batch_state_vector. ptr_f32 ( ) . ok_or_else ( || {
640+ MahoutError :: InvalidInput (
641+ "Batch state vector precision mismatch (expected float32 buffer)" . to_string ( ) ,
642+ )
643+ } ) ?;
644+ let ret = unsafe {
645+ launch_amplitude_encode_batch_f32 (
646+ input_batch_d,
647+ state_ptr as * mut c_void ,
648+ * inv_norms_gpu. device_ptr ( ) as * const f32 ,
649+ num_samples,
650+ sample_size,
651+ state_len,
652+ stream,
653+ )
654+ } ;
655+ if ret != 0 {
656+ return Err ( MahoutError :: KernelLaunch ( format ! (
657+ "Batch kernel launch failed with CUDA error code: {} ({})" ,
658+ ret,
659+ cuda_error_to_string( ret)
660+ ) ) ) ;
661+ }
662+ }
663+ {
664+ crate :: profile_scope!( "GPU::Synchronize" ) ;
665+ sync_cuda_stream ( stream, "CUDA stream synchronize failed" ) ?;
666+ }
667+ Ok ( batch_state_vector)
668+ }
669+
460670 fn name ( & self ) -> & ' static str {
461671 "amplitude"
462672 }
0 commit comments