@@ -7,7 +7,7 @@ Rust-to-CUDA transpiler for RingKernel GPU kernels.
77This crate enables writing GPU kernels in a restricted Rust DSL and transpiling them to CUDA C code. It supports three kernel types:
88
991 . ** Global Kernels** - Standard CUDA ` __global__ ` functions
10- 2 . ** Stencil Kernels** - Tile-based kernels with ` GridPos ` abstraction
10+ 2 . ** Stencil Kernels** - Tile-based kernels with ` GridPos ` abstraction (2D and 3D)
11113 . ** Ring Kernels** - Persistent actor kernels with message loops
1212
1313## Installation
@@ -39,11 +39,12 @@ let cuda_code = transpile_global_kernel(&func)?;
3939
4040## Stencil Kernels
4141
42- For grid-based computations with neighbor access:
42+ For grid-based computations with neighbor access (2D and 3D):
4343
4444``` rust
45- use ringkernel_cuda_codegen :: {transpile_stencil_kernel, StencilConfig };
45+ use ringkernel_cuda_codegen :: {transpile_stencil_kernel, StencilConfig , Grid };
4646
47+ // 2D stencil
4748let func : syn :: ItemFn = parse_quote! {
4849 fn fdtd (p : & [f32 ], p_prev : & mut [f32 ], c2 : f32 , pos : GridPos ) {
4950 let lap = pos . north (p ) + pos . south (p ) + pos . east (p ) + pos . west (p )
@@ -53,10 +54,25 @@ let func: syn::ItemFn = parse_quote! {
5354};
5455
5556let config = StencilConfig :: new (" fdtd" )
57+ . with_grid (Grid :: Grid2D )
5658 . with_tile_size (16 , 16 )
5759 . with_halo (1 );
5860
5961let cuda_code = transpile_stencil_kernel (& func , & config )? ;
62+
63+ // 3D stencil with up/down neighbors
64+ let func_3d : syn :: ItemFn = parse_quote! {
65+ fn laplacian_3d (p : & [f32 ], out : & mut [f32 ], pos : GridPos ) {
66+ let lap = pos . north (p ) + pos . south (p ) + pos . east (p ) + pos . west (p )
67+ + pos . up (p ) + pos . down (p ) - 6.0 * p [pos . idx ()];
68+ out [pos . idx ()] = lap ;
69+ }
70+ };
71+
72+ let config_3d = StencilConfig :: new (" laplacian_3d" )
73+ . with_grid (Grid :: Grid3D )
74+ . with_tile_size (8 , 8 )
75+ . with_halo (1 );
6076```
6177
6278## Ring Kernels
@@ -86,35 +102,149 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
86102## DSL Reference
87103
88104### Thread/Block Indices
89- - ` thread_idx_x() ` , ` thread_idx_y() ` , ` thread_idx_z() `
90- - ` block_idx_x() ` , ` block_idx_y() ` , ` block_idx_z() `
91- - ` block_dim_x() ` , ` block_dim_y() ` , ` block_dim_z() `
92- - ` grid_dim_x() ` , ` grid_dim_y() ` , ` grid_dim_z() `
105+ - ` thread_idx_x() ` , ` thread_idx_y() ` , ` thread_idx_z() ` → ` threadIdx.x/y/z `
106+ - ` block_idx_x() ` , ` block_idx_y() ` , ` block_idx_z() ` → ` blockIdx.x/y/z `
107+ - ` block_dim_x() ` , ` block_dim_y() ` , ` block_dim_z() ` → ` blockDim.x/y/z `
108+ - ` grid_dim_x() ` , ` grid_dim_y() ` , ` grid_dim_z() ` → ` gridDim.x/y/z `
109+ - ` warp_size() ` → ` warpSize `
93110
94- ### Stencil Intrinsics
111+ ### Stencil Intrinsics (2D)
95112- ` pos.idx() ` - Linear index
96- - ` pos.north(buf) ` , ` pos.south(buf) ` , ` pos.east(buf) ` , ` pos.west(buf) `
113+ - ` pos.north(buf) ` , ` pos.south(buf) ` - Y-axis neighbors
114+ - ` pos.east(buf) ` , ` pos.west(buf) ` - X-axis neighbors
97115- ` pos.at(buf, dx, dy) ` - Relative offset access
98116
99- ### Synchronization
100- - ` sync_threads() ` - Block-level barrier
101- - ` thread_fence() ` - Device memory fence
102- - ` thread_fence_block() ` - Block memory fence
103-
104- ### Atomics
105- - ` atomic_add(ptr, val) ` , ` atomic_sub(ptr, val) `
106- - ` atomic_min(ptr, val) ` , ` atomic_max(ptr, val) `
107- - ` atomic_exchange(ptr, val) ` , ` atomic_cas(ptr, compare, val) `
117+ ### Stencil Intrinsics (3D)
118+ - ` pos.up(buf) ` , ` pos.down(buf) ` - Z-axis neighbors (3D kernels can still use ` pos.idx() ` and the north/south/east/west accessors)
119+ - ` pos.at(buf, dx, dy, dz) ` - 3D relative offset access
108120
109- ### Math Functions
110- - ` sqrt() ` , ` abs() ` , ` floor() ` , ` ceil() ` , ` round() `
111- - ` sin() ` , ` cos() ` , ` tan() ` , ` exp() ` , ` log() `
112- - ` powf() ` , ` min() ` , ` max() ` , ` mul_add() `
121+ ### Synchronization
122+ - ` sync_threads() ` → ` __syncthreads() ` - Block-level barrier
123+ - ` sync_threads_count(pred) ` → ` __syncthreads_count() ` - Count threads with predicate
124+ - ` sync_threads_and(pred) ` → ` __syncthreads_and() ` - AND of predicate
125+ - ` sync_threads_or(pred) ` → ` __syncthreads_or() ` - OR of predicate
126+ - ` thread_fence() ` → ` __threadfence() ` - Device memory fence
127+ - ` thread_fence_block() ` → ` __threadfence_block() ` - Block memory fence
128+ - ` thread_fence_system() ` → ` __threadfence_system() ` - System memory fence
129+
130+ ### Atomic Operations (Integer)
131+ - ` atomic_add(ptr, val) ` → ` atomicAdd `
132+ - ` atomic_sub(ptr, val) ` → ` atomicSub `
133+ - ` atomic_min(ptr, val) ` → ` atomicMin `
134+ - ` atomic_max(ptr, val) ` → ` atomicMax `
135+ - ` atomic_exchange(ptr, val) ` → ` atomicExch `
136+ - ` atomic_cas(ptr, compare, val) ` → ` atomicCAS `
137+ - ` atomic_and(ptr, val) ` → ` atomicAnd `
138+ - ` atomic_or(ptr, val) ` → ` atomicOr `
139+ - ` atomic_xor(ptr, val) ` → ` atomicXor `
140+ - ` atomic_inc(ptr, val) ` → ` atomicInc ` (increment with wrap)
141+ - ` atomic_dec(ptr, val) ` → ` atomicDec ` (decrement with wrap)
142+
143+ ### Basic Math Functions
144+ - ` sqrt() ` , ` rsqrt() ` - Square root, reciprocal sqrt
145+ - ` abs() ` , ` fabs() ` - Absolute value
146+ - ` floor() ` , ` ceil() ` , ` round() ` , ` trunc() ` - Rounding
147+ - ` fma() ` , ` mul_add() ` - Fused multiply-add
148+ - ` fmin() ` , ` fmax() ` - Minimum, maximum
149+ - ` fmod() ` , ` remainder() ` - Modulo operations
150+ - ` copysign() ` - Copy sign
151+ - ` cbrt() ` - Cube root
152+ - ` hypot() ` - Hypotenuse
153+
154+ ### Trigonometric Functions
155+ - ` sin() ` , ` cos() ` , ` tan() ` - Basic trig
156+ - ` asin() ` , ` acos() ` , ` atan() ` , ` atan2() ` - Inverse trig
157+ - ` sincos() ` - Combined sine and cosine
158+ - ` sinpi() ` , ` cospi() ` - Sine/cosine of π·x
159+
160+ ### Hyperbolic Functions
161+ - ` sinh() ` , ` cosh() ` , ` tanh() ` - Hyperbolic
162+ - ` asinh() ` , ` acosh() ` , ` atanh() ` - Inverse hyperbolic
163+
164+ ### Exponential and Logarithmic Functions
165+ - ` exp() ` , ` exp2() ` , ` exp10() ` , ` expm1() ` - Exponentials
166+ - ` log() ` , ` ln() ` , ` log2() ` , ` log10() ` , ` log1p() ` - Logarithms
167+ - ` pow() ` , ` powf() ` , ` powi() ` - Power
168+ - ` ldexp() ` , ` scalbn() ` - Load/scale exponent
169+ - ` ilogb() ` - Extract exponent
170+ - ` erf() ` , ` erfc() ` , ` erfinv() ` , ` erfcinv() ` - Error functions
171+ - ` lgamma() ` , ` tgamma() ` - Gamma functions
172+
173+ ### Classification Functions
174+ - ` is_nan() ` , ` isnan() ` → ` isnan `
175+ - ` is_infinite() ` , ` isinf() ` → ` isinf `
176+ - ` is_finite() ` , ` isfinite() ` → ` isfinite `
177+ - ` is_normal() ` , ` isnormal() ` → ` isnormal `
178+ - ` signbit() ` - Check sign bit
179+ - ` nextafter() ` - Next representable value
180+ - ` fdim() ` - Positive difference
113181
114182### Warp Operations
115- - ` warp_shuffle(val, lane) ` , ` warp_shuffle_up(val, delta) `
116- - ` warp_shuffle_down(val, delta) ` , ` warp_shuffle_xor(val, mask) `
117- - ` warp_ballot(pred) ` , ` warp_all(pred) ` , ` warp_any(pred) `
183+ - ` warp_active_mask() ` → ` __activemask() ` - Active lane mask
184+ - ` warp_shfl(mask, val, lane) ` → ` __shfl_sync ` - Shuffle
185+ - ` warp_shfl_up(mask, val, delta) ` → ` __shfl_up_sync `
186+ - ` warp_shfl_down(mask, val, delta) ` → ` __shfl_down_sync `
187+ - ` warp_shfl_xor(mask, val, lane_mask) ` → ` __shfl_xor_sync `
188+ - ` warp_ballot(mask, pred) ` → ` __ballot_sync `
189+ - ` warp_all(mask, pred) ` → ` __all_sync `
190+ - ` warp_any(mask, pred) ` → ` __any_sync `
191+
192+ ### Warp Match Operations (Volta+, SM 7.0+)
193+ - ` warp_match_any(mask, val) ` → ` __match_any_sync `
194+ - ` warp_match_all(mask, val) ` → ` __match_all_sync `
195+
196+ ### Warp Reduce Operations (SM 8.0+)
197+ - ` warp_reduce_add(mask, val) ` → ` __reduce_add_sync `
198+ - ` warp_reduce_min(mask, val) ` → ` __reduce_min_sync `
199+ - ` warp_reduce_max(mask, val) ` → ` __reduce_max_sync `
200+ - ` warp_reduce_and(mask, val) ` → ` __reduce_and_sync `
201+ - ` warp_reduce_or(mask, val) ` → ` __reduce_or_sync `
202+ - ` warp_reduce_xor(mask, val) ` → ` __reduce_xor_sync `
203+
204+ ### Bit Manipulation
205+ - ` popc() ` , ` popcount() ` , ` count_ones() ` → ` __popc ` - Population count
206+ - ` clz() ` , ` leading_zeros() ` → ` __clz ` - Count leading zeros
207+ - ` ctz() ` , ` trailing_zeros() ` → ` __ffs(x) - 1 ` - Count trailing zeros
208+ - ` ffs() ` → ` __ffs ` - Find first set
209+ - ` brev() ` , ` reverse_bits() ` → ` __brev ` - Bit reverse
210+ - ` byte_perm() ` → ` __byte_perm ` - Byte permutation
211+ - ` funnel_shift_left() ` → ` __funnelshift_l `
212+ - ` funnel_shift_right() ` → ` __funnelshift_r `
213+
214+ ### Memory Operations
215+ - ` ldg(ptr) ` , ` load_global(ptr) ` → ` __ldg ` - Read-only cache load
216+ - ` prefetch_l1(ptr) ` → ` __prefetch_l1 ` - L1 prefetch
217+ - ` prefetch_l2(ptr) ` → ` __prefetch_l2 ` - L2 prefetch
218+
219+ ### Special Functions
220+ - ` rcp() ` , ` recip() ` → ` __frcp_rn ` - Fast reciprocal
221+ - ` fast_div() ` → ` __fdividef ` - Fast division
222+ - ` saturate() ` , ` clamp_01() ` → ` __saturatef ` - Clamp to the range [0, 1]
223+ - ` j0() ` , ` j1() ` , ` jn() ` - Bessel functions of the first kind
224+ - ` y0() ` , ` y1() ` , ` yn() ` - Bessel functions of the second kind
225+ - ` normcdf() ` , ` normcdfinv() ` - Standard normal CDF and its inverse
226+ - ` cyl_bessel_i0() ` , ` cyl_bessel_i1() ` - Modified cylindrical Bessel functions of order 0 and 1
227+
228+ ### Clock and Timing
229+ - ` clock() ` → ` clock() ` - 32-bit clock counter
230+ - ` clock64() ` → ` clock64() ` - 64-bit clock counter
231+ - ` nanosleep(ns) ` → ` __nanosleep ` - Sleep for nanoseconds
232+
233+ ### RingContext Methods
234+ - ` ctx.thread_id() ` → ` threadIdx.x `
235+ - ` ctx.block_id() ` → ` blockIdx.x `
236+ - ` ctx.global_thread_id() ` → ` (blockIdx.x * blockDim.x + threadIdx.x) `
237+ - ` ctx.sync_threads() ` → ` __syncthreads() `
238+ - ` ctx.lane_id() ` → ` (threadIdx.x % 32) `
239+ - ` ctx.warp_id() ` → ` (threadIdx.x / 32) `
240+
241+ ### Ring Kernel Intrinsics
242+ - ` is_active() ` , ` should_terminate() ` , ` mark_terminated() `
243+ - ` messages_processed() ` , ` input_queue_size() ` , ` output_queue_size() `
244+ - ` input_queue_empty() ` , ` output_queue_empty() ` , ` enqueue_response(&resp) `
245+ - ` hlc_tick() ` , ` hlc_update(ts) ` , ` hlc_now() ` - HLC operations
246+ - ` k2k_send(target, &msg) ` , ` k2k_try_recv() ` - K2K messaging
247+ - ` k2k_has_message() ` , ` k2k_peek() ` , ` k2k_pending_count() `
118248
119249## Type Mapping
120250
@@ -130,13 +260,33 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
130260| ` &[T] ` | ` const T* __restrict__ ` |
131261| ` &mut [T] ` | ` T* __restrict__ ` |
132262
263+ ## Intrinsic Count
264+
265+ The transpiler supports ** 120+ GPU intrinsics** across 13 categories:
266+
267+ | Category | Count | Examples |
268+ | ----------| -------| ----------|
269+ | Synchronization | 7 | ` sync_threads ` , ` thread_fence ` |
270+ | Atomics | 11 | ` atomic_add ` , ` atomic_cas ` , ` atomic_and ` |
271+ | Math | 16 | ` sqrt ` , ` fma ` , ` cbrt ` , ` hypot ` |
272+ | Trigonometric | 11 | ` sin ` , ` asin ` , ` atan2 ` , ` sincos ` |
273+ | Hyperbolic | 6 | ` sinh ` , ` asinh ` |
274+ | Exponential | 18 | ` exp ` , ` log2 ` , ` erf ` , ` tgamma ` |
275+ | Classification | 8 | ` isnan ` , ` isfinite ` , ` signbit ` |
276+ | Warp | 16 | ` warp_shfl ` , ` warp_reduce_add ` , ` warp_match_any ` |
277+ | Bit Manipulation | 8 | ` popc ` , ` clz ` , ` brev ` , ` funnel_shift_left ` |
278+ | Memory | 3 | ` ldg ` , ` prefetch_l1 ` |
279+ | Special | 13 | ` rcp ` , ` saturate ` , ` normcdf ` |
280+ | Index | 13 | ` thread_idx_x ` , ` warp_size ` |
281+ | Timing | 3 | ` clock ` , ` clock64 ` , ` nanosleep ` |
282+
133283## Testing
134284
135285``` bash
136286cargo test -p ringkernel-cuda-codegen
137287```
138288
139- The crate includes 143 tests covering all kernel types and language features.
289+ The crate includes 171 tests covering all kernel types, intrinsics, and language features.
140290
141291## License
142292
0 commit comments