| 1 | +# Plan: Implement `vstack` atom |
| 2 | + |
| 3 | +## Context |
| 4 | + |
| 5 | +The engine needs a `vstack` (vertical stack) operation that concatenates expressions along the row dimension. With `hstack`, column-major storage makes the output a simple flat concatenation of the children's values. With `vstack`, each output column interleaves slices from every child, so the forward pass and Jacobian need row-level remapping and the Hessian needs its weight vector regathered per child. |
| 6 | + |
| 7 | +## Core challenge: column-major interleaving |
| 8 | + |
| 9 | +For `vstack` of children with shapes `(r_0, d2), (r_1, d2), ...`: |
| 10 | +- Output shape: `(R, d2)` where `R = sum(r_i)` |
| 11 | +- Column `j` of the output = `[child_0 col j, child_1 col j, ...]` stacked |
| 12 | +- Output flat index `k` maps to: `global_row = k % R`, `col = k / R` (integer division) |
| 13 | +- Finding which child owns `global_row` would otherwise mean scanning the children's row counts, so the mapping is precomputed as a lookup table |
| 14 | + |
| 15 | +## Implementation |
| 16 | + |
| 17 | +### 1. Type-specific struct (`include/subexpr.h`) |
| 18 | + |
| 19 | +```c |
| 20 | +typedef struct vstack_expr |
| 21 | +{ |
| 22 | + expr base; |
| 23 | + expr **args; |
| 24 | + int n_args; |
| 25 | + int *row_offsets; /* row_offsets[i] = sum of args[0..i-1]->d1 */ |
| 26 | + int *row_to_child; /* length R: maps global row -> child index */ |
| 27 | + int *row_to_local; /* length R: maps global row -> local row in child */ |
| 28 | + CSR_Matrix *CSR_work; /* for Hessian accumulation */ |
| 29 | +} vstack_expr; |
| 30 | +``` |
| 31 | + |
| 32 | +Precompute `row_to_child` and `row_to_local` (each of length `R`) once in the constructor. These let every later operation do O(1) lookups instead of scanning children. |
| 33 | + |
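The lookup tables described above could be built roughly as follows. This is a minimal sketch: the helper name `build_row_maps` and the plain `rows` array of child row counts are assumptions for illustration, not part of the engine's API, and the caller is assumed to have already allocated `row_to_child` and `row_to_local` with length `R`.

```c
/* Hypothetical helper (not an engine function): fills
 *   row_offsets  (length n_args): first global row owned by each child
 *   row_to_child (length R):      global row -> child index
 *   row_to_local (length R):      global row -> row inside that child
 * from the children's row counts, and returns R = sum of rows[i]. */
int build_row_maps(const int *rows, int n_args, int *row_offsets,
                   int *row_to_child, int *row_to_local)
{
    int R = 0;
    for (int i = 0; i < n_args; i++)
    {
        row_offsets[i] = R;
        for (int r = 0; r < rows[i]; r++)
        {
            row_to_child[R + r] = i;
            row_to_local[R + r] = r;
        }
        R += rows[i];
    }
    return R;
}
```

For children with 2 and 1 rows, this gives `row_offsets = {0, 2}`, `row_to_child = {0, 0, 1}` and `row_to_local = {0, 1, 0}`.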
| 34 | +### 2. Constructor (`src/affine/vstack.c` — `new_vstack`) |
| 35 | + |
| 36 | +- Assert all children have same `d2` |
| 37 | +- Compute `R = sum(r_i)`, output is `(R, d2)` |
| 38 | +- Allocate `vstack_expr`, call `init_expr` |
| 39 | +- Build `row_offsets`, `row_to_child`, `row_to_local` arrays |
| 40 | +- Retain all children |
| 41 | + |
| 42 | +### 3. Forward pass |
| 43 | + |
| 44 | +Loop over `d2` columns; within each column, loop over children and `memcpy` each child's column slice: |
| 45 | + |
| 46 | +``` |
| 47 | +offset = 0 |
| 48 | +for j in 0..d2-1: |
| 49 | + for i in 0..n_args-1: |
| 50 | + memcpy(value + offset, args[i]->value + j * r_i, r_i * sizeof(double)) |
| 51 | + offset += r_i |
| 52 | +``` |
| 53 | + |
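The pseudocode above might translate to C along these lines. It is a sketch only: `mini_expr` is a hypothetical stand-in that carries just the `d1`, `d2` and `value` fields the plan refers to, and `vstack_forward` is an illustrative name rather than the engine's actual function.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical stand-in for the engine's expr: only the fields used here. */
typedef struct
{
    int d1;        /* number of rows */
    int d2;        /* number of columns */
    double *value; /* column-major data, length d1 * d2 */
} mini_expr;

/* Copies each child's column slice into the stacked output, one output
 * column at a time. `out` must hold (sum of d1) * d2 doubles. */
void vstack_forward(double *out, mini_expr *const *args, int n_args, int d2)
{
    size_t offset = 0;
    for (int j = 0; j < d2; j++)
    {
        for (int i = 0; i < n_args; i++)
        {
            size_t r_i = (size_t) args[i]->d1;
            memcpy(out + offset, args[i]->value + (size_t) j * r_i,
                   r_i * sizeof(double));
            offset += r_i;
        }
    }
}
```

Because `offset` advances by `r_i` after each copy, the forward pass needs no lookup tables; they only matter for the row-wise Jacobian loops.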
| 54 | +### 4. Jacobian init |
| 55 | + |
| 56 | +The output Jacobian is `(size, n_vars)` in CSR, where `size = R * d2`. Each output row `k` corresponds to exactly one child Jacobian row, so the output's nonzero count is the sum of the children's. Iterate output rows in order and use the precomputed mapping to copy column indices from the matching child row: |
| 57 | + |
| 58 | +``` |
| 59 | +for k in 0..size-1: |
| 60 | + global_row = k % R |
| 61 | + col = k / R |
| 62 | + child_idx = row_to_child[global_row] |
| 63 | + local_row = row_to_local[global_row] |
| 64 | + child_flat = col * args[child_idx]->d1 + local_row |
| 65 | + copy column indices from child's Jacobian row child_flat |
| 66 | + set J->p[k+1] |
| 67 | +``` |
| 68 | + |
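The row-remapping loop above could be sketched in C as follows. The `mini_csr` struct is a hypothetical minimal CSR with row pointers `p`, column indices `i` and values `x` (the field names the plan uses); the real `CSR_Matrix` likely has more fields, and the function name and argument list are assumptions for illustration.

```c
/* Hypothetical minimal CSR: p = row pointers, i = column indices, x = values. */
typedef struct
{
    int *p;
    int *i;
    double *x;
} mini_csr;

/* Fills J->p and J->i for the stacked Jacobian. child_J[c] is child c's
 * Jacobian and child_rows[c] its row count (d1). J->p needs R * d2 + 1
 * entries and J->i needs room for the sum of the children's nonzeros. */
void vstack_jacobian_pattern(mini_csr *J, const mini_csr *const *child_J,
                             const int *child_rows, const int *row_to_child,
                             const int *row_to_local, int R, int d2)
{
    int nnz = 0;
    J->p[0] = 0;
    for (int k = 0; k < R * d2; k++)
    {
        int global_row = k % R;
        int col = k / R;
        int c = row_to_child[global_row];
        int child_flat = col * child_rows[c] + row_to_local[global_row];
        const mini_csr *A = child_J[c];
        /* Copy the column indices of the matching child row. */
        for (int q = A->p[child_flat]; q < A->p[child_flat + 1]; q++)
        {
            J->i[nnz++] = A->i[q];
        }
        J->p[k + 1] = nnz;
    }
}
```

The value pass would run the same loop but copy `A->x[q]` into `J->x` instead of the column indices.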
| 69 | +### 5. eval_jacobian |
| 70 | + |
| 71 | +Same loop as jacobian_init but copies values (`x`) instead of column indices (`i`). |
| 72 | + |
| 73 | +### 6. Hessian (wsum_hess_init / eval) |
| 74 | + |
| 75 | +**Sparsity init**: Identical to hstack — iteratively union children's Hessian patterns using `sum_csr_matrices_fill_sparsity`. |
| 76 | + |
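| 77 | +**Eval**: The weight vector `w` has one entry per output element, in the output's column-major order. Each child needs its own weights gathered into a contiguous vector in the child's column-major order. Allocate `dwork` with the size of the largest child (the maximum of `args[i]->size`). For each child `i`, with `r_i = args[i]->d1`, gather its weights: |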
| 78 | + |
| 79 | +``` |
| 80 | +for j in 0..d2-1: |
| 81 | + for r in 0..r_i-1: |
| 82 | + dwork[j * r_i + r] = w[j * R + row_offsets[i] + r] |
| 83 | +``` |
| 84 | + |
| 85 | +Then call `child->eval_wsum_hess(child, dwork)` and accumulate with `sum_csr_matrices_fill_values`. This follows the same pattern transpose uses for weight permutation (`transpose.c:111-117`). |
| 86 | + |
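The weight gather for a single child could look like this in C. It is a sketch under the plan's own indexing; the function name `gather_child_weights` and its argument list are illustrative assumptions, and the call to the child's `eval_wsum_hess` is left out because its exact signature is not shown here.

```c
/* Gathers one child's weights from the stacked weight vector w
 * (column-major with R rows per column) into dwork (column-major with
 * r_i rows per column). row_offset is that child's entry in row_offsets. */
void gather_child_weights(double *dwork, const double *w, int R, int d2,
                          int row_offset, int r_i)
{
    for (int j = 0; j < d2; j++)
    {
        for (int r = 0; r < r_i; r++)
        {
            dwork[j * r_i + r] = w[j * R + row_offset + r];
        }
    }
}
```

With `R = 3`, `d2 = 3` and `w = {0, 1, ..., 8}`, a child that owns rows 0 and 1 receives `{0, 1, 3, 4, 6, 7}`, and a child that owns row 2 receives `{2, 5, 8}`.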
| 87 | +### 7. is_affine / free_type_data |
| 88 | + |
| 89 | +Same pattern as hstack: `is_affine` checks all children, and `free_type_data` frees all children plus the auxiliary arrays. |
| 90 | + |
| 91 | +## Files to modify |
| 92 | + |
| 93 | +| File | Change | |
| 94 | +|------|--------| |
| 95 | +| `include/subexpr.h` | Add `vstack_expr` struct | |
| 96 | +| `include/affine.h` | Declare `new_vstack` | |
| 97 | +| `src/affine/vstack.c` | New file — full implementation | |
| 98 | +| `src/affine/hstack.c` | Reference only (pattern to follow) | |
| 99 | +| `src/affine/transpose.c` | Reference only (weight permutation pattern) | |
| 100 | +| `tests/forward_pass/affine/test_vstack.h` | New — forward pass tests | |
| 101 | +| `tests/jacobian_tests/test_vstack.h` | New — Jacobian tests | |
| 102 | +| `tests/wsum_hess/test_vstack.h` | New — Hessian tests | |
| 103 | +| `tests/all_tests.c` | Include new test headers, add `mu_run_test` calls | |
| 104 | + |
| 105 | +## Verification |
| 106 | + |
| 107 | +1. `cmake -B build -S . -DCMAKE_BUILD_TYPE=Debug && cmake --build build` |
| 108 | +2. `./build/all_tests` — all existing + new vstack tests pass |
| 109 | +3. Test cases should cover: |
| 110 | + - Vectors: `vstack([log(x), exp(x)])` where x is `(3,1)`, output `(6,1)` (degenerate case: with `d2 = 1` the flat output is a plain concatenation, the same data layout hstack produces) |
| 111 | + - Matrices: `vstack([log(x), exp(y)])` where x is `(2,3)`, y is `(1,3)` — output `(3,3)` — verifies interleaving |
| 112 | + - Hessian with shared variables: `vstack([log(x), exp(x)])` — verifies weight gathering and Hessian accumulation |