Skip to content

Commit e15105f

Browse files
Transurgeonclaude
andcommitted
Merge origin/main into parameter-support-v2
Resolves conflicts from main's workspace refactor (iwork/dwork -> work->), elementwise atom split (elementwise_univariate -> full_dom/restricted_dom), chain rule refactor, vstack addition, and numerical diff checker. Removed stale test_jacobian_composite_log tests (incompatible with main's chain rule refactor; covered by test_chain_rule_jacobian and numerical diff). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents 61d657f + 7f671cb commit e15105f

90 files changed

Lines changed: 1603 additions & 955 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ if(NOT SKBUILD)
103103
add_executable(all_tests
104104
tests/all_tests.c
105105
tests/test_helpers.c
106+
tests/numerical_diff.c
106107
)
107108
target_link_libraries(all_tests dnlp_diff)
108109

include/affine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ expr *new_neg(expr *child);
2929

3030
expr *new_sum(expr *child, int axis);
3131
expr *new_hstack(expr **args, int n_args, int n_vars);
32+
expr *new_vstack(expr **args, int n_args, int n_vars);
3233
expr *new_promote(expr *child, int d1, int d2);
3334
expr *new_trace(expr *child);
3435

include/elementwise_full_dom.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef ELEMENTWISE_FULL_DOM_H
2+
#define ELEMENTWISE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Helper function to initialize an elementwise expr
7+
* (can be used with derived types) */
8+
void init_elementwise(expr *node, expr *child);
9+
10+
expr *new_exp(expr *child);
11+
expr *new_sin(expr *child);
12+
expr *new_cos(expr *child);
13+
expr *new_sinh(expr *child);
14+
expr *new_tanh(expr *child);
15+
expr *new_asinh(expr *child);
16+
expr *new_logistic(expr *child);
17+
expr *new_power(expr *child, double p);
18+
expr *new_xexp(expr *child);
19+
expr *new_normal_cdf(expr *child);
20+
21+
/* the jacobian and wsum_hess for elementwise full domain
22+
atoms are always initialized in the same way and
23+
implement the chain rule in the same way */
24+
void jacobian_init_elementwise(expr *node);
25+
void eval_jacobian_elementwise(expr *node);
26+
void wsum_hess_init_elementwise(expr *node);
27+
void eval_wsum_hess_elementwise(expr *node, const double *w);
28+
expr *new_elementwise(expr *child);
29+
30+
/* no elementwise atoms are affine according to our
31+
convention, so we can have a common implementation */
32+
bool is_affine_elementwise(const expr *node);
33+
34+
#endif /* ELEMENTWISE_FULL_DOM_H */
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef ELEMENTWISE_RESTRICTED_DOM_H
2+
#define ELEMENTWISE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Shared init functions for restricted domain atoms
7+
* (variable-child only, no linear operator support) */
8+
void jacobian_init_restricted(expr *node);
9+
void wsum_hess_init_restricted(expr *node);
10+
bool is_affine_restricted(const expr *node);
11+
expr *new_restricted(expr *child);
12+
13+
expr *new_log(expr *child);
14+
expr *new_entr(expr *child);
15+
expr *new_atanh(expr *child);
16+
expr *new_tan(expr *child);
17+
18+
#endif /* ELEMENTWISE_RESTRICTED_DOM_H */

include/elementwise_univariate.h

Lines changed: 0 additions & 54 deletions
This file was deleted.

include/expr.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,24 @@ typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double
3939
typedef bool (*is_affine_fn)(const struct expr *node);
4040
typedef void (*free_type_data_fn)(struct expr *node);
4141

42-
/* Base expression node structure - contains only common fields */
42+
/* Workspace for derivative computation */
43+
typedef struct
44+
{
45+
double *dwork;
46+
int *iwork;
47+
CSC_Matrix *jacobian_csc;
48+
int *csc_work; /* for CSR-CSC conversion */
49+
50+
/* jacobian_csc_filled is only used for affine functions to avoid redundant
51+
conversions. Could become relevant for non-affine functions if we start
52+
supporting common subexpressions on the Python side. */
53+
bool jacobian_csc_filled;
54+
double *local_jac_diag; /* cached f'(g(x)) diagonal */
55+
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
56+
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
57+
} Expr_Work;
58+
59+
/* Base expression node structure */
4360
typedef struct expr
4461
{
4562
// ------------------------------------------------------------------------
@@ -48,8 +65,6 @@ typedef struct expr
4865
int d1, d2, size, n_vars, refcount, var_id;
4966
struct expr *left;
5067
struct expr *right;
51-
double *dwork;
52-
int *iwork;
5368

5469
// ------------------------------------------------------------------------
5570
// oracle related quantities
@@ -70,6 +85,7 @@ typedef struct expr
7085
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
7186
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
7287
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
88+
Expr_Work *work; /* derivative workspace */
7389

7490
// name of node just for debugging - should be removed later
7591
char name[32];
@@ -83,6 +99,10 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
8399

84100
void free_expr(expr *node);
85101

102+
/* Initialize CSC form of the Jacobian from the CSR Jacobian.
103+
* Must be called after jacobian_init. */
104+
void jacobian_csc_init(expr *node);
105+
86106
/* Reference counting helpers */
87107
void expr_retain(expr *node);
88108

include/utils/CSR_Matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ typedef struct CSR_Matrix
2828
/* constructors and destructors */
2929
CSR_Matrix *new_csr_matrix(int m, int n, int nnz);
3030
CSR_Matrix *new_csr(const CSR_Matrix *A);
31+
CSR_Matrix *new_csr_copy_sparsity(const CSR_Matrix *A);
3132
void free_csr_matrix(CSR_Matrix *matrix);
3233
void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C);
3334

src/affine/broadcast.c

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,10 @@ static void wsum_hess_init(expr *node)
191191
x->wsum_hess_init(x);
192192

193193
/* Same sparsity as child - weights get summed */
194-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
195-
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->wsum_hess->m + 1) * sizeof(int));
196-
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
194+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
197195

198196
/* allocate space for weight vector */
199-
node->dwork = malloc(node->size * sizeof(double));
197+
node->work->dwork = malloc(node->size * sizeof(double));
200198
}
201199

202200
static void eval_wsum_hess(expr *node, const double *w)
@@ -205,7 +203,7 @@ static void eval_wsum_hess(expr *node, const double *w)
205203
expr *x = node->left;
206204

207205
/* Zero out the work array first */
208-
memset(node->dwork, 0, x->size * sizeof(double));
206+
memset(node->work->dwork, 0, x->size * sizeof(double));
209207

210208
if (bcast->type == BROADCAST_ROW)
211209
{
@@ -214,7 +212,7 @@ static void eval_wsum_hess(expr *node, const double *w)
214212
{
215213
for (int i = 0; i < node->d1; i++)
216214
{
217-
node->dwork[j] += w[i + j * node->d1];
215+
node->work->dwork[j] += w[i + j * node->d1];
218216
}
219217
}
220218
}
@@ -225,21 +223,21 @@ static void eval_wsum_hess(expr *node, const double *w)
225223
{
226224
for (int i = 0; i < node->d1; i++)
227225
{
228-
node->dwork[i] += w[i + j * node->d1];
226+
node->work->dwork[i] += w[i + j * node->d1];
229227
}
230228
}
231229
}
232230
else
233231
{
234232
/* (1, 1) -> (m, n): scalar has m*n weights to sum */
235-
node->dwork[0] = 0.0;
233+
node->work->dwork[0] = 0.0;
236234
for (int k = 0; k < node->size; k++)
237235
{
238-
node->dwork[0] += w[k];
236+
node->work->dwork[0] += w[k];
239237
}
240238
}
241239

242-
x->eval_wsum_hess(x, node->dwork);
240+
x->eval_wsum_hess(x, node->work->dwork);
243241
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
244242
}
245243

src/affine/diag_vec.c

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,12 @@ static void wsum_hess_init(expr *node)
100100
x->wsum_hess_init(x);
101101

102102
/* workspace for extracting diagonal weights */
103-
node->dwork = (double *) calloc(x->size, sizeof(double));
103+
node->work->dwork = (double *) calloc(x->size, sizeof(double));
104104

105105
/* Copy child's Hessian structure (diag_vec is linear, so its own Hessian is
106106
* zero) */
107107
CSR_Matrix *Hx = x->wsum_hess;
108-
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
109-
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
110-
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
108+
node->wsum_hess = new_csr_copy_sparsity(Hx);
111109
}
112110

113111
static void eval_wsum_hess(expr *node, const double *w)
@@ -118,11 +116,11 @@ static void eval_wsum_hess(expr *node, const double *w)
118116
/* Extract weights from diagonal positions of w (which has n^2 elements) */
119117
for (int i = 0; i < n; i++)
120118
{
121-
node->dwork[i] = w[i * (n + 1)];
119+
node->work->dwork[i] = w[i * (n + 1)];
122120
}
123121

124122
/* Evaluate child's Hessian with extracted weights */
125-
x->eval_wsum_hess(x, node->dwork);
123+
x->eval_wsum_hess(x, node->work->dwork);
126124
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
127125
}
128126

src/affine/index.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ static void wsum_hess_init(expr *node)
104104
x->wsum_hess_init(x);
105105

106106
/* for setting weight vector to evaluate hessian of child */
107-
node->dwork = (double *) calloc(x->size, sizeof(double));
107+
node->work->dwork = (double *) calloc(x->size, sizeof(double));
108108

109109
/* in the implementation of eval_wsum_hess we evaluate the
110110
child's hessian with a weight vector that has w[i] = 0
@@ -113,9 +113,7 @@ static void wsum_hess_init(expr *node)
113113
structural zeros, but we do not try to exploit that sparsity
114114
right now. */
115115
CSR_Matrix *Hx = x->wsum_hess;
116-
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
117-
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
118-
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
116+
node->wsum_hess = new_csr_copy_sparsity(Hx);
119117
}
120118

121119
static void eval_wsum_hess(expr *node, const double *w)
@@ -126,23 +124,23 @@ static void eval_wsum_hess(expr *node, const double *w)
126124
if (idx->has_duplicates)
127125
{
128126
/* zero and accumulate for repeated indices */
129-
memset(node->dwork, 0, x->size * sizeof(double));
127+
memset(node->work->dwork, 0, x->size * sizeof(double));
130128
for (int i = 0; i < idx->n_idxs; i++)
131129
{
132-
node->dwork[idx->indices[i]] += w[i];
130+
node->work->dwork[idx->indices[i]] += w[i];
133131
}
134132
}
135133
else
136134
{
137135
/* direct write (no memset needed, no accumulation) */
138136
for (int i = 0; i < idx->n_idxs; i++)
139137
{
140-
node->dwork[idx->indices[i]] = w[i];
138+
node->work->dwork[idx->indices[i]] = w[i];
141139
}
142140
}
143141

144142
/* evalute hessian of child */
145-
x->eval_wsum_hess(x, node->dwork);
143+
x->eval_wsum_hess(x, node->work->dwork);
146144
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
147145
}
148146

0 commit comments

Comments
 (0)