Commit a91a823

Merge pull request #8 from mivertowski/claude/expand-cuda-intrinsics-015QNYA1vEFCNZoLMdLafWQ8
Expand CUDA intrinsics coverage to 120+ operations
2 parents 4e20f3f + f0cc426 commit a91a823

5 files changed: 2217 additions & 74 deletions

File tree

CLAUDE.md

Lines changed: 5 additions & 4 deletions
@@ -192,14 +192,15 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
```

**DSL Features:**
- Block/grid indices: `block_idx_x()`, `thread_idx_x()`, `block_dim_x()`, `grid_dim_x()`, `warp_size()`, etc.
- Control flow: `if/else`, `match` → switch/case, early `return`
- Loops: `for i in 0..n`, `while cond`, `loop` with `break`/`continue`
- Stencil intrinsics (2D): `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`, `pos.at(buf, dx, dy)`
- Stencil intrinsics (3D): `pos.up(buf)`, `pos.down(buf)`, `pos.at(buf, dx, dy, dz)` for volumetric kernels
- Shared memory: `__shared__` arrays and tiles with `SharedMemoryConfig`
- Struct literals: `Point { x: 1.0, y: 2.0 }` → C compound literals
- Reference expressions: `&arr[idx]` → pointer to element with automatic `->` operator for field access
- 120+ GPU intrinsics across 13 categories (synchronization, atomics, math, trig, hyperbolic, exponential, classification, warp, bit manipulation, memory, special, index, timing)
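The 2D stencil intrinsics reduce to plain neighbor indexing on a row-major grid. As a CPU-side sketch of that semantics (the `GridPos` struct below is a hypothetical mirror written for illustration, not the crate's actual type):

```rust
/// Hypothetical CPU mirror of the 2D stencil intrinsics: north/south/east/west
/// read the four axis neighbors of a cell on a row-major `width`-wide grid.
/// Direction conventions (north = y - 1) are an assumption for this sketch.
struct GridPos {
    x: usize,
    y: usize,
    width: usize,
}

impl GridPos {
    fn idx(&self) -> usize {
        self.y * self.width + self.x
    }
    fn north(&self, buf: &[f32]) -> f32 {
        buf[(self.y - 1) * self.width + self.x]
    }
    fn south(&self, buf: &[f32]) -> f32 {
        buf[(self.y + 1) * self.width + self.x]
    }
    fn east(&self, buf: &[f32]) -> f32 {
        buf[self.y * self.width + self.x + 1]
    }
    fn west(&self, buf: &[f32]) -> f32 {
        buf[self.y * self.width + self.x - 1]
    }
    fn at(&self, buf: &[f32], dx: isize, dy: isize) -> f32 {
        let x = (self.x as isize + dx) as usize;
        let y = (self.y as isize + dy) as usize;
        buf[y * self.width + x]
    }
}

fn main() {
    // 3x3 grid holding 0.0..=8.0, center cell at (1, 1).
    let buf: Vec<f32> = (0..9).map(|i| i as f32).collect();
    let pos = GridPos { x: 1, y: 1, width: 3 };
    // 5-point Laplacian: north + south + east + west - 4 * center.
    let lap = pos.north(&buf) + pos.south(&buf) + pos.east(&buf) + pos.west(&buf)
        - 4.0 * buf[pos.idx()];
    println!("{}", lap); // 1 + 7 + 5 + 3 - 16 = 0
}
```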

**Ring Kernel Features:**
- Persistent message loop with ControlBlock lifecycle management
@@ -282,7 +283,7 @@ Main crate (`ringkernel`) features:
- ringkernel-core: 65 tests
- ringkernel-cpu: 11 tests
- ringkernel-cuda: 6 GPU execution tests
- ringkernel-cuda-codegen: 171 tests (loops, shared memory, ring kernels, K2K, reference expressions, 120+ GPU intrinsics)
- ringkernel-wgpu-codegen: 50 tests (types, intrinsics, transpiler, validation)
- ringkernel-derive: 14 macro tests
- ringkernel-wavesim: 49 tests (including educational modes)

crates/ringkernel-cuda-codegen/README.md

Lines changed: 176 additions & 26 deletions
@@ -7,7 +7,7 @@ Rust-to-CUDA transpiler for RingKernel GPU kernels.
This crate enables writing GPU kernels in a restricted Rust DSL and transpiling them to CUDA C code. It supports three kernel types:

1. **Global Kernels** - Standard CUDA `__global__` functions
2. **Stencil Kernels** - Tile-based kernels with `GridPos` abstraction (2D and 3D)
3. **Ring Kernels** - Persistent actor kernels with message loops

## Installation
@@ -39,11 +39,12 @@ let cuda_code = transpile_global_kernel(&func)?;

## Stencil Kernels

For grid-based computations with neighbor access (2D and 3D):

```rust
use ringkernel_cuda_codegen::{transpile_stencil_kernel, StencilConfig, Grid};

// 2D stencil
let func: syn::ItemFn = parse_quote! {
    fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
        let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p)
        // ... (lines elided in this diff: @@ -53,10 +54,25 @@)
    }
};

let config = StencilConfig::new("fdtd")
    .with_grid(Grid::Grid2D)
    .with_tile_size(16, 16)
    .with_halo(1);

let cuda_code = transpile_stencil_kernel(&func, &config)?;

// 3D stencil with up/down neighbors
let func_3d: syn::ItemFn = parse_quote! {
    fn laplacian_3d(p: &[f32], out: &mut [f32], pos: GridPos) {
        let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p)
            + pos.up(p) + pos.down(p) - 6.0 * p[pos.idx()];
        out[pos.idx()] = lap;
    }
};

let config_3d = StencilConfig::new("laplacian")
    .with_grid(Grid::Grid3D)
    .with_tile_size(8, 8)
    .with_halo(1);
```
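As a sanity check on what the transpiled 3D kernel computes per cell, here is a hedged CPU reference for the 7-point Laplacian above. The layout (row-major, x fastest) and the mapping of north/south to the y axis are assumptions for this sketch, not the crate's guaranteed conventions:

```rust
/// CPU reference for the 7-point 3D Laplacian kernel body, assuming a
/// row-major layout with x fastest: idx = (z * ny + y) * nx + x.
fn laplacian_3d_cell(p: &[f32], nx: usize, ny: usize, x: usize, y: usize, z: usize) -> f32 {
    let idx = |x: usize, y: usize, z: usize| (z * ny + y) * nx + x;
    p[idx(x, y - 1, z)] + p[idx(x, y + 1, z)]       // north + south
        + p[idx(x + 1, y, z)] + p[idx(x - 1, y, z)] // east + west
        + p[idx(x, y, z + 1)] + p[idx(x, y, z - 1)] // up + down
        - 6.0 * p[idx(x, y, z)]
}

fn main() {
    // Constant field: the Laplacian is zero in the interior.
    let p = vec![1.0f32; 3 * 3 * 3];
    assert_eq!(laplacian_3d_cell(&p, 3, 3, 1, 1, 1), 0.0);

    // Linear field f = x: also zero (second differences of a linear field vanish).
    let mut q = vec![0.0f32; 27];
    for z in 0..3 {
        for y in 0..3 {
            for x in 0..3 {
                q[(z * 3 + y) * 3 + x] = x as f32;
            }
        }
    }
    assert_eq!(laplacian_3d_cell(&q, 3, 3, 1, 1, 1), 0.0);
}
```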
## Ring Kernels
@@ -86,35 +102,149 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
## DSL Reference

### Thread/Block Indices
- `thread_idx_x()`, `thread_idx_y()`, `thread_idx_z()` → `threadIdx.x/y/z`
- `block_idx_x()`, `block_idx_y()`, `block_idx_z()` → `blockIdx.x/y/z`
- `block_dim_x()`, `block_dim_y()`, `block_dim_z()` → `blockDim.x/y/z`
- `grid_dim_x()`, `grid_dim_y()`, `grid_dim_z()` → `gridDim.x/y/z`
- `warp_size()` → `warpSize`
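These index intrinsics compose into the usual global-index and grid-stride patterns. A CPU emulation (a hypothetical helper illustrating only the arithmetic) shows that every element is visited exactly once:

```rust
/// Emulates how the index intrinsics compose on the CPU: each (block, thread)
/// pair walks the standard grid-stride loop,
/// global_id = block_idx * block_dim + thread_idx, stride = grid_dim * block_dim.
fn grid_stride_indices(grid_dim: usize, block_dim: usize, n: usize) -> Vec<usize> {
    let mut visited = Vec::new();
    for block_idx in 0..grid_dim {
        for thread_idx in 0..block_dim {
            let mut i = block_idx * block_dim + thread_idx;
            while i < n {
                visited.push(i);
                i += grid_dim * block_dim;
            }
        }
    }
    visited.sort();
    visited
}

fn main() {
    // 2 blocks x 4 threads covering 10 elements: each index visited exactly once.
    assert_eq!(grid_stride_indices(2, 4, 10), (0..10usize).collect::<Vec<_>>());
}
```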

### Stencil Intrinsics (2D)
- `pos.idx()` - Linear index
- `pos.north(buf)`, `pos.south(buf)` - Y-axis neighbors
- `pos.east(buf)`, `pos.west(buf)` - X-axis neighbors
- `pos.at(buf, dx, dy)` - Relative offset access

### Stencil Intrinsics (3D)
- `pos.up(buf)`, `pos.down(buf)` - Z-axis neighbors
- `pos.at(buf, dx, dy, dz)` - 3D relative offset access

### Synchronization
- `sync_threads()` → `__syncthreads()` - Block-level barrier
- `sync_threads_count(pred)` → `__syncthreads_count()` - Count threads with predicate
- `sync_threads_and(pred)` → `__syncthreads_and()` - AND of predicate
- `sync_threads_or(pred)` → `__syncthreads_or()` - OR of predicate
- `thread_fence()` → `__threadfence()` - Device memory fence
- `thread_fence_block()` → `__threadfence_block()` - Block memory fence
- `thread_fence_system()` → `__threadfence_system()` - System memory fence
### Atomic Operations (Integer)
- `atomic_add(ptr, val)` → `atomicAdd`
- `atomic_sub(ptr, val)` → `atomicSub`
- `atomic_min(ptr, val)` → `atomicMin`
- `atomic_max(ptr, val)` → `atomicMax`
- `atomic_exchange(ptr, val)` → `atomicExch`
- `atomic_cas(ptr, compare, val)` → `atomicCAS`
- `atomic_and(ptr, val)` → `atomicAnd`
- `atomic_or(ptr, val)` → `atomicOr`
- `atomic_xor(ptr, val)` → `atomicXor`
- `atomic_inc(ptr, val)` → `atomicInc` (increment with wrap)
- `atomic_dec(ptr, val)` → `atomicDec` (decrement with wrap)
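`atomicInc`/`atomicDec` have wrap semantics that differ from plain add/sub. A sequential Rust model of the behavior documented in the CUDA programming guide (no real atomicity here, just the value computation):

```rust
/// Wrap semantics of atomicInc/atomicDec per the CUDA programming guide,
/// modeled sequentially: atomicInc stores ((old >= val) ? 0 : old + 1),
/// atomicDec stores (((old == 0) || (old > val)) ? val : old - 1);
/// both return the old value.
fn atomic_inc(slot: &mut u32, val: u32) -> u32 {
    let old = *slot;
    *slot = if old >= val { 0 } else { old + 1 };
    old
}

fn atomic_dec(slot: &mut u32, val: u32) -> u32 {
    let old = *slot;
    *slot = if old == 0 || old > val { val } else { old - 1 };
    old
}

fn main() {
    let mut counter = 0u32;
    // Counting modulo val + 1: 0 -> 1 -> 2 -> 0.
    for _ in 0..3 {
        atomic_inc(&mut counter, 2);
    }
    assert_eq!(counter, 0); // wrapped after reaching val
    atomic_dec(&mut counter, 2);
    assert_eq!(counter, 2); // 0 wraps back to val
}
```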
### Basic Math Functions
- `sqrt()`, `rsqrt()` - Square root, reciprocal square root
- `abs()`, `fabs()` - Absolute value
- `floor()`, `ceil()`, `round()`, `trunc()` - Rounding
- `fma()`, `mul_add()` - Fused multiply-add
- `fmin()`, `fmax()` - Minimum, maximum
- `fmod()`, `remainder()` - Modulo operations
- `copysign()` - Copy sign
- `cbrt()` - Cube root
- `hypot()` - Hypotenuse

### Trigonometric Functions
- `sin()`, `cos()`, `tan()` - Basic trig
- `asin()`, `acos()`, `atan()`, `atan2()` - Inverse trig
- `sincos()` - Combined sine and cosine
- `sinpi()`, `cospi()` - Sine/cosine of π*x

### Hyperbolic Functions
- `sinh()`, `cosh()`, `tanh()` - Hyperbolic
- `asinh()`, `acosh()`, `atanh()` - Inverse hyperbolic

### Exponential and Logarithmic Functions
- `exp()`, `exp2()`, `exp10()`, `expm1()` - Exponentials
- `log()`, `ln()`, `log2()`, `log10()`, `log1p()` - Logarithms
- `pow()`, `powf()`, `powi()` - Power
- `ldexp()`, `scalbn()` - Load/scale exponent
- `ilogb()` - Extract exponent
- `erf()`, `erfc()`, `erfinv()`, `erfcinv()` - Error functions
- `lgamma()`, `tgamma()` - Gamma functions

### Classification Functions
- `is_nan()`, `isnan()` → `isnan`
- `is_infinite()`, `isinf()` → `isinf`
- `is_finite()`, `isfinite()` → `isfinite`
- `is_normal()`, `isnormal()` → `isnormal`
- `signbit()` - Check sign bit
- `nextafter()` - Next representable value
- `fdim()` - Positive difference
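The classification intrinsics line up with Rust's own float predicates, which is worth seeing side by side. A plain-std sketch (`fdim`, the positive difference, is modeled by hand since Rust has no stable equivalent):

```rust
/// Positive difference: a - b when a > b, else 0 (models the fdim intrinsic).
fn fdim(a: f64, b: f64) -> f64 {
    if a > b { a - b } else { 0.0 }
}

fn main() {
    assert!(f64::NAN.is_nan());            // isnan
    assert!(f64::INFINITY.is_infinite());  // isinf
    assert!(1.0f64.is_finite());           // isfinite
    assert!(1.0f64.is_normal());           // isnormal
    assert!((-0.0f64).is_sign_negative()); // signbit is set for -0.0
    assert_eq!(fdim(3.0, 1.0), 2.0);
    assert_eq!(fdim(1.0, 3.0), 0.0);
}
```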

### Warp Operations
- `warp_active_mask()` → `__activemask()` - Active lane mask
- `warp_shfl(mask, val, lane)` → `__shfl_sync` - Shuffle
- `warp_shfl_up(mask, val, delta)` → `__shfl_up_sync`
- `warp_shfl_down(mask, val, delta)` → `__shfl_down_sync`
- `warp_shfl_xor(mask, val, lane_mask)` → `__shfl_xor_sync`
- `warp_ballot(mask, pred)` → `__ballot_sync`
- `warp_all(mask, pred)` → `__all_sync`
- `warp_any(mask, pred)` → `__any_sync`

### Warp Match Operations (Volta+)
- `warp_match_any(mask, val)` → `__match_any_sync`
- `warp_match_all(mask, val)` → `__match_all_sync`

### Warp Reduce Operations (SM 8.0+)
- `warp_reduce_add(mask, val)` → `__reduce_add_sync`
- `warp_reduce_min(mask, val)` → `__reduce_min_sync`
- `warp_reduce_max(mask, val)` → `__reduce_max_sync`
- `warp_reduce_and(mask, val)` → `__reduce_and_sync`
- `warp_reduce_or(mask, val)` → `__reduce_or_sync`
- `warp_reduce_xor(mask, val)` → `__reduce_xor_sync`
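The shuffle primitives exist to build patterns like the classic warp sum reduction. A CPU emulation of the `warp_shfl_down`-style reduction, with the 32 lanes modeled as an array (a sketch of the pattern, not real warp execution):

```rust
/// Emulates the classic warp sum reduction built from shuffle-down:
/// at each step, lane i adds the value held by lane i + offset, halving
/// the offset until lane 0 holds the sum of all 32 lanes.
fn warp_reduce_sum(mut lanes: [u32; 32]) -> u32 {
    let mut offset = 16;
    while offset > 0 {
        for i in 0..32 {
            // __shfl_down_sync reads from lane i + offset; lanes whose source
            // is out of range keep their own value (their result is unused).
            // Ascending order is safe: lane i + offset is only written later.
            if i + offset < 32 {
                lanes[i] += lanes[i + offset];
            }
        }
        offset /= 2;
    }
    lanes[0] // lane 0 holds the full sum
}

fn main() {
    let lanes: [u32; 32] = core::array::from_fn(|i| i as u32);
    assert_eq!(warp_reduce_sum(lanes), (0..32).sum::<u32>()); // 496
}
```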
### Bit Manipulation
- `popc()`, `popcount()`, `count_ones()` → `__popc` - Population count
- `clz()`, `leading_zeros()` → `__clz` - Count leading zeros
- `ctz()`, `trailing_zeros()` → `__ffs - 1` - Count trailing zeros
- `ffs()` → `__ffs` - Find first set
- `brev()`, `reverse_bits()` → `__brev` - Bit reverse
- `byte_perm()` → `__byte_perm` - Byte permutation
- `funnel_shift_left()` → `__funnelshift_l`
- `funnel_shift_right()` → `__funnelshift_r`
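Most of these have exact Rust counterparts, which the DSL's alias names reflect. A plain-Rust sketch of the correspondences (the `ffs` helper is hand-written; `__ffs(0)` returns 0 by convention):

```rust
/// The bit-manipulation intrinsics map onto Rust integer methods:
/// popc -> count_ones, clz -> leading_zeros, ctz -> trailing_zeros,
/// brev -> reverse_bits, and ffs = trailing_zeros + 1 for nonzero inputs.
fn ffs(x: u32) -> u32 {
    if x == 0 { 0 } else { x.trailing_zeros() + 1 }
}

fn main() {
    let x: u32 = 0b1011_0000;
    assert_eq!(x.count_ones(), 3);      // popc
    assert_eq!(x.leading_zeros(), 24);  // clz
    assert_eq!(x.trailing_zeros(), 4);  // ctz
    assert_eq!(ffs(x), 5);              // __ffs: 1-based index of lowest set bit
    assert_eq!(0x1u32.reverse_bits(), 0x8000_0000); // brev
    assert_eq!(ffs(0), 0);
}
```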
### Memory Operations
- `ldg(ptr)`, `load_global(ptr)` → `__ldg` - Read-only cache load
- `prefetch_l1(ptr)` → `__prefetch_l1` - L1 prefetch
- `prefetch_l2(ptr)` → `__prefetch_l2` - L2 prefetch
### Special Functions
- `rcp()`, `recip()` → `__frcp_rn` - Fast reciprocal
- `fast_div()` → `__fdividef` - Fast division
- `saturate()`, `clamp_01()` → `__saturatef` - Saturate to [0, 1]
- `j0()`, `j1()`, `jn()` - Bessel functions of the first kind
- `y0()`, `y1()`, `yn()` - Bessel functions of the second kind
- `normcdf()`, `normcdfinv()` - Normal CDF and its inverse
- `cyl_bessel_i0()`, `cyl_bessel_i1()` - Cylindrical Bessel functions
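A quick CPU model of the `saturate()` semantics (a sketch: the CUDA math API documents `__saturatef(NaN)` as 0, which is handled explicitly here because Rust's `clamp` propagates NaN instead):

```rust
/// Models __saturatef: clamp to [0, 1], with NaN mapping to 0 as documented
/// in the CUDA math API (Rust's clamp would propagate NaN, so special-case it).
fn saturate(x: f32) -> f32 {
    if x.is_nan() { 0.0 } else { x.clamp(0.0, 1.0) }
}

fn main() {
    assert_eq!(saturate(-0.5), 0.0);
    assert_eq!(saturate(0.25), 0.25);
    assert_eq!(saturate(3.0), 1.0);
    assert_eq!(saturate(f32::NAN), 0.0);
}
```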
### Clock and Timing
- `clock()` → `clock()` - 32-bit clock counter
- `clock64()` → `clock64()` - 64-bit clock counter
- `nanosleep(ns)` → `__nanosleep` - Sleep for nanoseconds
### RingContext Methods
- `ctx.thread_id()` → `threadIdx.x`
- `ctx.block_id()` → `blockIdx.x`
- `ctx.global_thread_id()` → `(blockIdx.x * blockDim.x + threadIdx.x)`
- `ctx.sync_threads()` → `__syncthreads()`
- `ctx.lane_id()` → `(threadIdx.x % 32)`
- `ctx.warp_id()` → `(threadIdx.x / 32)`
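The RingContext index methods are plain arithmetic on the flat thread index, as the mappings above show; a CPU sketch with a warp size of 32 assumed (hypothetical helper names, for illustration only):

```rust
/// Plain-arithmetic models of the RingContext index methods (warp size 32).
fn global_thread_id(block_idx: u32, block_dim: u32, thread_idx: u32) -> u32 {
    block_idx * block_dim + thread_idx
}
fn lane_id(thread_idx: u32) -> u32 { thread_idx % 32 }
fn warp_id(thread_idx: u32) -> u32 { thread_idx / 32 }

fn main() {
    // Thread 70 in block 3 of 128-thread blocks:
    assert_eq!(global_thread_id(3, 128, 70), 454);
    assert_eq!(lane_id(70), 6); // 70 % 32
    assert_eq!(warp_id(70), 2); // 70 / 32
}
```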
### Ring Kernel Intrinsics
- `is_active()`, `should_terminate()`, `mark_terminated()`
- `messages_processed()`, `input_queue_size()`, `output_queue_size()`
- `input_queue_empty()`, `output_queue_empty()`, `enqueue_response(&resp)`
- `hlc_tick()`, `hlc_update(ts)`, `hlc_now()` - HLC operations
- `k2k_send(target, &msg)`, `k2k_try_recv()` - K2K messaging
- `k2k_has_message()`, `k2k_peek()`, `k2k_pending_count()`

## Type Mapping

@@ -130,13 +260,33 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
| `&[T]` | `const T* __restrict__` |
| `&mut [T]` | `T* __restrict__` |

## Intrinsic Count

The transpiler supports **120+ GPU intrinsics** across 13 categories:

| Category | Count | Examples |
|----------|-------|----------|
| Synchronization | 7 | `sync_threads`, `thread_fence` |
| Atomics | 11 | `atomic_add`, `atomic_cas`, `atomic_and` |
| Math | 16 | `sqrt`, `fma`, `cbrt`, `hypot` |
| Trigonometric | 11 | `sin`, `asin`, `atan2`, `sincos` |
| Hyperbolic | 6 | `sinh`, `asinh` |
| Exponential | 18 | `exp`, `log2`, `erf`, `gamma` |
| Classification | 8 | `isnan`, `isfinite`, `signbit` |
| Warp | 16 | `warp_shfl`, `warp_reduce_add`, `warp_match_any` |
| Bit Manipulation | 8 | `popc`, `clz`, `brev`, `funnel_shift_left` |
| Memory | 3 | `ldg`, `prefetch_l1` |
| Special | 13 | `rcp`, `saturate`, `normcdf` |
| Index | 13 | `thread_idx_x`, `warp_size` |
| Timing | 3 | `clock`, `clock64`, `nanosleep` |
## Testing

```bash
cargo test -p ringkernel-cuda-codegen
```

The crate includes 171 tests covering all kernel types, intrinsics, and language features.

## License
