Commit f30f6f1

Transurgeon and claude committed
Add vstack atom implementation plan
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7e7875d commit f30f6f1

1 file changed

Lines changed: 112 additions & 0 deletions

docs/vstack-plan.md

@@ -0,0 +1,112 @@
# Plan: Implement `vstack` atom

## Context

The engine needs a `vstack` (vertical stack) operation that concatenates expressions along the row dimension. With `hstack`, column-major storage makes the output values a simple flat concatenation of the children's values. With `vstack`, each output column interleaves slices from every child, so every operation (forward pass, Jacobian, Hessian) needs row-level remapping.

## Core challenge: column-major interleaving

For `vstack` of children with shapes `(r_0, d2), (r_1, d2), ...`:
- Output shape: `(R, d2)` where `R = sum(r_i)`
- Column `j` of the output = `[child_0 col j, child_1 col j, ...]` stacked
- Output flat index `k` maps to: `global_row = k % R`, `col = k / R` (integer division)
- Finding which child owns `global_row` requires a precomputed lookup
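
The index arithmetic in the list above can be sketched as follows. This is a minimal illustration, not code from the repository: `rc_index`, `split_flat_index`, and `find_owner_by_scan` are hypothetical names, and `child_rows[i]` stands in for `r_i`. The linear scan shows the per-lookup cost that a precomputed table is meant to avoid.

```c
#include <assert.h>

/* Hypothetical helper type: row and column of a column-major flat index. */
typedef struct { int row; int col; } rc_index;

/* Split flat index k of an (R, d2) column-major matrix into (row, col). */
rc_index split_flat_index(int k, int R)
{
    assert(R > 0 && k >= 0);
    rc_index rc;
    rc.row = k % R; /* position inside the stacked column */
    rc.col = k / R; /* integer division selects the column */
    return rc;
}

/* Find the child that owns global_row by scanning the children in order.
 * Returns the child index and writes the row inside that child to
 * *local_row, or returns -1 if global_row is out of range. */
int find_owner_by_scan(const int *child_rows, int n_children,
                       int global_row, int *local_row)
{
    int start = 0;
    for (int i = 0; i < n_children; i++) {
        if (global_row < start + child_rows[i]) {
            *local_row = global_row - start;
            return i;
        }
        start += child_rows[i];
    }
    return -1;
}
```

For children of shapes `(2,3)` and `(1,3)`, `R = 3`, so flat index `5` lands on global row `2` of column `1`, which is local row `0` of the second child.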

## Implementation

### 1. Type-specific struct (`include/subexpr.h`)

```c
typedef struct vstack_expr
{
    expr base;
    expr **args;
    int n_args;
    int *row_offsets;     /* row_offsets[i] = sum of args[0..i-1]->d1 */
    int *row_to_child;    /* length R: maps global row -> child index */
    int *row_to_local;    /* length R: maps global row -> local row in child */
    CSR_Matrix *CSR_work; /* for Hessian accumulation */
} vstack_expr;
```

Precompute `row_to_child[R]` and `row_to_local[R]` once in the constructor. These let every later operation do O(1) lookups instead of scanning children.
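
One possible way to fill the three lookup arrays is sketched below. It assumes the caller has already allocated them and passes the children's row counts as a plain array (`child_rows[i]` plays the role of `args[i]->d1`). The function name `build_row_maps` is hypothetical.

```c
#include <assert.h>

/* Fill row_offsets (n_children entries) plus row_to_child and row_to_local
 * (R entries each) from the children's row counts. Returns R. */
int build_row_maps(const int *child_rows, int n_children,
                   int *row_offsets, int *row_to_child, int *row_to_local)
{
    int R = 0;
    for (int i = 0; i < n_children; i++) {
        assert(child_rows[i] > 0);
        row_offsets[i] = R; /* rows contributed by children 0..i-1 */
        for (int r = 0; r < child_rows[i]; r++) {
            row_to_child[R + r] = i; /* owner of this global row */
            row_to_local[R + r] = r; /* row index inside the owner */
        }
        R += child_rows[i];
    }
    return R;
}
```

Because `R` is only known after the loop, a real constructor would likely sum the row counts first, allocate, and then run a pass like this one.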

### 2. Constructor (`src/affine/vstack.c`, `new_vstack`)

- Assert all children have the same `d2`
- Compute `R = sum(r_i)`; the output shape is `(R, d2)`
- Allocate `vstack_expr`, call `init_expr`
- Build the `row_offsets`, `row_to_child`, and `row_to_local` arrays
- Retain all children
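
The shape check and the computation of `R` from the first two steps might look like the sketch below. `shape2d` is a hypothetical stand-in for the dimensions stored on `expr`, and the function returns a status code instead of asserting on a mismatch, purely so the failure path can be exercised in isolation.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for the (d1, d2) dimensions of an expression. */
typedef struct { int d1; int d2; } shape2d;

/* Returns 1 and writes the stacked shape to *out if every child has the
 * same d2; returns 0 (leaving *out untouched) otherwise. */
int vstack_output_shape(const shape2d *children, int n_children, shape2d *out)
{
    assert(children != NULL && out != NULL && n_children > 0);
    int R = 0;
    for (int i = 0; i < n_children; i++) {
        if (children[i].d2 != children[0].d2) {
            return 0; /* column counts disagree */
        }
        R += children[i].d1;
    }
    out->d1 = R;
    out->d2 = children[0].d2;
    return 1;
}
```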

### 3. Forward pass

Loop over `d2` columns; within each column, loop over children and `memcpy` each child's column slice:

```
offset = 0
for j in 0..d2-1:
    for i in 0..n_args-1:
        memcpy(value + offset, args[i]->value + j * r_i, r_i * sizeof(double))
        offset += r_i
```
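
The pseudocode above translates fairly directly to C. The sketch below works on plain arrays rather than the engine's `expr` type, so `child_values` and `child_rows` are assumed stand-ins for `args[i]->value` and `args[i]->d1`, and `vstack_forward` is a hypothetical name.

```c
#include <assert.h>
#include <string.h>

/* child_values[i] points to child i's column-major data of shape
 * (child_rows[i], d2). out must hold sum(child_rows) * d2 doubles. */
void vstack_forward(const double *const *child_values, const int *child_rows,
                    int n_children, int d2, double *out)
{
    assert(n_children > 0 && d2 > 0);
    size_t offset = 0;
    for (int j = 0; j < d2; j++) {
        for (int i = 0; i < n_children; i++) {
            size_t r_i = (size_t) child_rows[i];
            /* copy column j of child i into the next slot of output column j */
            memcpy(out + offset, child_values[i] + (size_t) j * r_i,
                   r_i * sizeof(double));
            offset += r_i;
        }
    }
}
```

With `x = [[1,3,5],[2,4,6]]` stored as `{1,2,3,4,5,6}` and `y = [[7,8,9]]`, the output buffer is `{1,2,7,3,4,8,5,6,9}`, which shows `y`'s entries interleaved after each pair from `x`.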

### 4. Jacobian init

The output Jacobian is `(size, n_vars)` in CSR. Each output row `k` corresponds to a specific child Jacobian row. Iterate over the output rows in order and use the precomputed mapping to copy column indices from the right child row:

```
for k in 0..size-1:
    global_row = k % R
    col = k / R
    child_idx = row_to_child[global_row]
    local_row = row_to_local[global_row]
    child_flat = col * args[child_idx]->d1 + local_row
    copy column indices from child's Jacobian row child_flat
    set J->p[k+1]
```
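
A runnable version of this loop is sketched below under a simplifying assumption: `csr_sketch` is a hypothetical minimal CSR layout (row pointers `p`, column indices `i`, values `x`), and the project's actual `CSR_Matrix` may be laid out differently. The caller is assumed to have allocated `J->p` with `R * d2 + 1` entries and `J->i` with room for the children's combined nonzeros.

```c
#include <assert.h>
#include <string.h>

/* Hypothetical minimal CSR layout; the real CSR_Matrix may differ. */
typedef struct {
    int m;     /* number of rows */
    int *p;    /* row pointers, length m + 1 */
    int *i;    /* column indices, length p[m] */
    double *x; /* values, length p[m] (unused in this sketch) */
} csr_sketch;

/* Build the sparsity pattern of the stacked Jacobian row by row. */
void vstack_jacobian_pattern(csr_sketch *J, const csr_sketch *const *child_jacs,
                             const int *child_rows, const int *row_to_child,
                             const int *row_to_local, int R, int d2)
{
    assert(J->m == R * d2);
    J->p[0] = 0;
    for (int k = 0; k < R * d2; k++) {
        int global_row = k % R;
        int col = k / R;
        int c = row_to_child[global_row];
        int child_flat = col * child_rows[c] + row_to_local[global_row];
        const csr_sketch *Jc = child_jacs[c];
        int start = Jc->p[child_flat];
        int len = Jc->p[child_flat + 1] - start;
        /* copy the column indices of the matching child row */
        memcpy(J->i + J->p[k], Jc->i + start, (size_t) len * sizeof(int));
        J->p[k + 1] = J->p[k] + len;
    }
}
```

The same row walk, copying `x` entries instead of `i` entries, would give the value pass.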

### 5. eval_jacobian

Same loop as jacobian_init but copies values (`x`) instead of column indices (`i`).

### 6. Hessian (wsum_hess_init / eval)

**Sparsity init**: Identical to hstack: iteratively union the children's Hessian patterns using `sum_csr_matrices_fill_sparsity`.

**Eval**: The weight vector `w` has one entry per output element. Each child needs a gathered weight vector. Allocate `dwork` of size `max(child->size)`. For each child, gather its weights:

```
for j in 0..d2-1:
    for r in 0..r_i-1:
        dwork[j * r_i + r] = w[j * R + row_offsets[i] + r]
```

Then call `child->eval_wsum_hess(child, dwork)` and accumulate with `sum_csr_matrices_fill_values`. This follows the same pattern transpose uses for weight permutation (`transpose.c:111-117`).
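
The gather step for one child can be written as a small standalone function. This sketch covers only the weight gathering, not the call into the child or the CSR accumulation. `gather_child_weights` is a hypothetical name, and `row_offset` corresponds to `row_offsets[i]`.

```c
#include <assert.h>

/* Gather the weights belonging to one child from w (length R * d2,
 * column-major over the stacked output) into dwork (length r_i * d2,
 * column-major over the child). */
void gather_child_weights(const double *w, int R, int d2,
                          int row_offset, int r_i, double *dwork)
{
    assert(row_offset >= 0 && row_offset + r_i <= R);
    for (int j = 0; j < d2; j++) {
        for (int r = 0; r < r_i; r++) {
            dwork[j * r_i + r] = w[j * R + row_offset + r];
        }
    }
}
```

For `R = 3`, `d2 = 3`, and `w = {1,...,9}`, a child with two rows at offset `0` receives `{1,2,4,5,7,8}`, and a one-row child at offset `2` receives `{3,6,9}`.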

### 7. is_affine / free_type_data

Same pattern as hstack: `is_affine` checks all children, and `free_type_data` frees all children plus the auxiliary arrays.

## Files to modify

| File | Change |
|------|--------|
| `include/subexpr.h` | Add `vstack_expr` struct |
| `include/affine.h` | Declare `new_vstack` |
| `src/affine/vstack.c` | New file: full implementation |
| `src/affine/hstack.c` | Reference only (pattern to follow) |
| `src/affine/transpose.c` | Reference only (weight permutation pattern) |
| `tests/forward_pass/affine/test_vstack.h` | New: forward pass tests |
| `tests/jacobian_tests/test_vstack.h` | New: Jacobian tests |
| `tests/wsum_hess/test_vstack.h` | New: Hessian tests |
| `tests/all_tests.c` | Include new test headers, add `mu_run_test` calls |

## Verification

1. `cmake -B build -S . -DCMAKE_BUILD_TYPE=Debug && cmake --build build`
2. `./build/all_tests`: all existing and new vstack tests pass
3. Test cases should cover:
   - Vectors: `vstack([log(x), exp(x)])` where x is `(3,1)`, output `(6,1)` (degenerate case: behaves like hstack for vectors)
   - Matrices: `vstack([log(x), exp(y)])` where x is `(2,3)`, y is `(1,3)`, output `(3,3)`; verifies interleaving
   - Hessian with shared variables: `vstack([log(x), exp(x)])`; verifies weight gathering and Hessian accumulation
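
For the forward-pass cases, one option is to check the optimized routine against a slow per-element oracle. The sketch below is an assumption about how such an oracle could look, not part of the planned test files. `vstack_forward_reference` is a hypothetical name, and it takes already evaluated child values (for example the results of `log(x)` and `exp(y)`) as plain arrays.

```c
#include <assert.h>

/* Slow reference: compute each output entry from (row, col) arithmetic and
 * a scan over the children, with no memcpy and no lookup tables. */
void vstack_forward_reference(const double *const *child_values,
                              const int *child_rows, int n_children,
                              int d2, double *out)
{
    int R = 0;
    for (int i = 0; i < n_children; i++) R += child_rows[i];
    assert(R > 0 && d2 > 0);
    for (int k = 0; k < R * d2; k++) {
        int row = k % R;
        int col = k / R;
        int c = 0;
        while (row >= child_rows[c]) { /* walk to the owning child */
            row -= child_rows[c];
            c++;
        }
        out[k] = child_values[c][col * child_rows[c] + row];
    }
}
```

In the vector case (`d2 = 1`) this reduces to a flat concatenation, which matches the degenerate case listed above.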
