Skip to content

Commit 60ea4b0

Browse files
authored
[WIP] chain rule for jacobian in multiply (#64)
* function for random matrices and two tessts * run forammter * free matrices * add guard for initializing jacobian or hessian twice * run formatter
1 parent 22bf057 commit 60ea4b0

88 files changed

Lines changed: 438 additions & 288 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/expr.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ typedef struct expr
7373
CSR_Matrix *jacobian;
7474
CSR_Matrix *wsum_hess;
7575
forward_fn forward;
76-
jacobian_init_fn jacobian_init;
77-
wsum_hess_init_fn wsum_hess_init;
76+
jacobian_init_fn jacobian_init_impl;
77+
wsum_hess_init_fn wsum_hess_init_impl;
7878
eval_jacobian_fn eval_jacobian;
7979
wsum_hess_fn eval_wsum_hess;
8080

@@ -99,6 +99,11 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
9999

100100
void free_expr(expr *node);
101101

102+
/* Guarded init: skips if already initialized (safe for DAGs
103+
* where a node may be visited through multiple parents). */
104+
void jacobian_init(expr *node);
105+
void wsum_hess_init(expr *node);
106+
102107
/* Initialize CSC form of the Jacobian from the CSR Jacobian.
103108
* Must be called after jacobian_init. */
104109
void jacobian_csc_init(expr *node);

src/affine/add.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ static void forward(expr *node, const double *u)
3434
}
3535
}
3636

37-
static void jacobian_init(expr *node)
37+
static void jacobian_init_impl(expr *node)
3838
{
3939
/* initialize children's jacobians */
40-
node->left->jacobian_init(node->left);
41-
node->right->jacobian_init(node->right);
40+
jacobian_init(node->left);
41+
jacobian_init(node->right);
4242

4343
/* we never have to store more than the sum of children's nnz */
4444
int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz;
@@ -60,11 +60,11 @@ static void eval_jacobian(expr *node)
6060
node->jacobian);
6161
}
6262

63-
static void wsum_hess_init(expr *node)
63+
static void wsum_hess_init_impl(expr *node)
6464
{
6565
/* initialize children's wsum_hess */
66-
node->left->wsum_hess_init(node->left);
67-
node->right->wsum_hess_init(node->right);
66+
wsum_hess_init(node->left);
67+
wsum_hess_init(node->right);
6868

6969
/* we never have to store more than the sum of children's nnz */
7070
int nnz_max = node->left->wsum_hess->nnz + node->right->wsum_hess->nnz;
@@ -95,8 +95,8 @@ expr *new_add(expr *left, expr *right)
9595
{
9696
assert(left->d1 == right->d1 && left->d2 == right->d2);
9797
expr *node = (expr *) calloc(1, sizeof(expr));
98-
init_expr(node, left->d1, left->d2, left->n_vars, forward, jacobian_init,
99-
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL);
98+
init_expr(node, left->d1, left->d2, left->n_vars, forward, jacobian_init_impl,
99+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
100100
node->left = left;
101101
node->right = right;
102102
expr_retain(left);

src/affine/broadcast.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ static void forward(expr *node, const double *u)
6666
}
6767
}
6868

69-
static void jacobian_init(expr *node)
69+
static void jacobian_init_impl(expr *node)
7070
{
7171
expr *x = node->left;
72-
x->jacobian_init(x);
72+
jacobian_init(x);
7373
broadcast_expr *bcast = (broadcast_expr *) node;
7474
int total_nnz;
7575

@@ -185,10 +185,10 @@ static void eval_jacobian(expr *node)
185185
}
186186
}
187187

188-
static void wsum_hess_init(expr *node)
188+
static void wsum_hess_init_impl(expr *node)
189189
{
190190
expr *x = node->left;
191-
x->wsum_hess_init(x);
191+
wsum_hess_init(x);
192192

193193
/* Same sparsity as child - weights get summed */
194194
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
@@ -279,8 +279,8 @@ expr *new_broadcast(expr *child, int d1, int d2)
279279
// --------------------------------------------------------------------------
280280
// initialize the rest of the expression
281281
// --------------------------------------------------------------------------
282-
init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian,
283-
is_affine, wsum_hess_init, eval_wsum_hess, NULL);
282+
init_expr(node, d1, d2, child->n_vars, forward, jacobian_init_impl,
283+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
284284
node->left = child;
285285
expr_retain(child);
286286
bcast->type = type;

src/affine/const_scalar_mult.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ static void forward(expr *node, const double *u)
3939
}
4040
}
4141

42-
static void jacobian_init(expr *node)
42+
static void jacobian_init_impl(expr *node)
4343
{
4444
expr *x = node->left;
4545

4646
/* initialize child jacobian */
47-
x->jacobian_init(x);
47+
jacobian_init(x);
4848

4949
/* same sparsity as child */
5050
node->jacobian = new_csr_copy_sparsity(x->jacobian);
@@ -65,12 +65,12 @@ static void eval_jacobian(expr *node)
6565
}
6666
}
6767

68-
static void wsum_hess_init(expr *node)
68+
static void wsum_hess_init_impl(expr *node)
6969
{
7070
expr *x = node->left;
7171

7272
/* initialize child's weighted Hessian */
73-
x->wsum_hess_init(x);
73+
wsum_hess_init(x);
7474

7575
/* same sparsity as child */
7676
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
@@ -100,8 +100,8 @@ expr *new_const_scalar_mult(double a, expr *child)
100100
(const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr));
101101
expr *node = &mult_node->base;
102102

103-
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init,
104-
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL);
103+
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init_impl,
104+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
105105
node->left = child;
106106
mult_node->a = a;
107107
expr_retain(child);

src/affine/const_vector_mult.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ static void forward(expr *node, const double *u)
3838
}
3939
}
4040

41-
static void jacobian_init(expr *node)
41+
static void jacobian_init_impl(expr *node)
4242
{
4343
expr *x = node->left;
4444

4545
/* initialize child jacobian */
46-
x->jacobian_init(x);
46+
jacobian_init(x);
4747

4848
/* same sparsity as child */
4949
node->jacobian = new_csr_copy_sparsity(x->jacobian);
@@ -67,12 +67,12 @@ static void eval_jacobian(expr *node)
6767
}
6868
}
6969

70-
static void wsum_hess_init(expr *node)
70+
static void wsum_hess_init_impl(expr *node)
7171
{
7272
expr *x = node->left;
7373

7474
/* initialize child's weighted Hessian */
75-
x->wsum_hess_init(x);
75+
wsum_hess_init(x);
7676

7777
/* same sparsity as child */
7878
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
@@ -115,8 +115,8 @@ expr *new_const_vector_mult(const double *a, expr *child)
115115
(const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr));
116116
expr *node = &vnode->base;
117117

118-
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init,
119-
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess,
118+
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init_impl,
119+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess,
120120
free_type_data);
121121
node->left = child;
122122
expr_retain(child);

src/affine/constant.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ static void forward(expr *node, const double *u)
2626
(void) u;
2727
}
2828

29-
static void jacobian_init(expr *node)
29+
static void jacobian_init_impl(expr *node)
3030
{
3131
/* Constant jacobian is all zeros: size x n_vars with 0 nonzeros.
3232
* new_csr_matrix uses calloc for row pointers, so they're already 0. */
@@ -39,7 +39,7 @@ static void eval_jacobian(expr *node)
3939
(void) node;
4040
}
4141

42-
static void wsum_hess_init(expr *node)
42+
static void wsum_hess_init_impl(expr *node)
4343
{
4444
/* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */
4545
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
@@ -61,8 +61,8 @@ static bool is_affine(const expr *node)
6161
expr *new_constant(int d1, int d2, int n_vars, const double *values)
6262
{
6363
expr *node = (expr *) calloc(1, sizeof(expr));
64-
init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine,
65-
wsum_hess_init, eval_wsum_hess, NULL);
64+
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
65+
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
6666
memcpy(node->value, values, node->size * sizeof(double));
6767

6868
return node;

src/affine/diag_vec.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ static void forward(expr *node, const double *u)
4444
}
4545
}
4646

47-
static void jacobian_init(expr *node)
47+
static void jacobian_init_impl(expr *node)
4848
{
4949
expr *x = node->left;
5050
int n = x->size;
51-
x->jacobian_init(x);
51+
jacobian_init(x);
5252

5353
CSR_Matrix *Jx = x->jacobian;
5454
CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz);
@@ -92,12 +92,12 @@ static void eval_jacobian(expr *node)
9292
}
9393
}
9494

95-
static void wsum_hess_init(expr *node)
95+
static void wsum_hess_init_impl(expr *node)
9696
{
9797
expr *x = node->left;
9898

9999
/* initialize child's wsum_hess */
100-
x->wsum_hess_init(x);
100+
wsum_hess_init(x);
101101

102102
/* workspace for extracting diagonal weights */
103103
node->work->dwork = (double *) calloc(x->size, sizeof(double));
@@ -137,8 +137,8 @@ expr *new_diag_vec(expr *child)
137137
/* n is the number of elements (works for both row and column vectors) */
138138
int n = child->size;
139139
expr *node = (expr *) calloc(1, sizeof(expr));
140-
init_expr(node, n, n, child->n_vars, forward, jacobian_init, eval_jacobian,
141-
is_affine, wsum_hess_init, eval_wsum_hess, NULL);
140+
init_expr(node, n, n, child->n_vars, forward, jacobian_init_impl, eval_jacobian,
141+
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
142142
node->left = child;
143143
expr_retain(child);
144144

src/affine/hstack.c

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ static void forward(expr *node, const double *u)
4242
}
4343
}
4444

45-
static void jacobian_init(expr *node)
45+
static void jacobian_init_impl(expr *node)
4646
{
4747
hstack_expr *hnode = (hstack_expr *) node;
4848

@@ -51,7 +51,7 @@ static void jacobian_init(expr *node)
5151
for (int i = 0; i < hnode->n_args; i++)
5252
{
5353
assert(hnode->args[i] != NULL);
54-
hnode->args[i]->jacobian_init(hnode->args[i]);
54+
jacobian_init(hnode->args[i]);
5555
nnz += hnode->args[i]->jacobian->nnz;
5656
}
5757

@@ -100,14 +100,14 @@ static void eval_jacobian(expr *node)
100100
}
101101
}
102102

103-
static void wsum_hess_init(expr *node)
103+
static void wsum_hess_init_impl(expr *node)
104104
{
105105
/* initialize children's hessians */
106106
hstack_expr *hnode = (hstack_expr *) node;
107107
int nnz = 0;
108108
for (int i = 0; i < hnode->n_args; i++)
109109
{
110-
hnode->args[i]->wsum_hess_init(hnode->args[i]);
110+
wsum_hess_init(hnode->args[i]);
111111
nnz += hnode->args[i]->wsum_hess->nnz;
112112
}
113113

@@ -187,8 +187,9 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
187187
/* Allocate the type-specific struct */
188188
hstack_expr *hnode = (hstack_expr *) calloc(1, sizeof(hstack_expr));
189189
expr *node = &hnode->base;
190-
init_expr(node, args[0]->d1, d2, n_vars, forward, jacobian_init, eval_jacobian,
191-
is_affine, wsum_hess_init, wsum_hess_eval, free_type_data);
190+
init_expr(node, args[0]->d1, d2, n_vars, forward, jacobian_init_impl,
191+
eval_jacobian, is_affine, wsum_hess_init_impl, wsum_hess_eval,
192+
free_type_data);
192193

193194
/* Set type-specific fields (deep copy args array) */
194195
hnode->args = (expr **) calloc(n_args, sizeof(expr *));

src/affine/index.c

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ static void forward(expr *node, const double *u)
5757
}
5858
}
5959

60-
static void jacobian_init(expr *node)
60+
static void jacobian_init_impl(expr *node)
6161
{
6262
expr *x = node->left;
6363
index_expr *idx = (index_expr *) node;
64-
x->jacobian_init(x);
64+
jacobian_init(x);
6565

6666
CSR_Matrix *Jx = x->jacobian;
6767
CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz);
@@ -96,12 +96,12 @@ static void eval_jacobian(expr *node)
9696
}
9797
}
9898

99-
static void wsum_hess_init(expr *node)
99+
static void wsum_hess_init_impl(expr *node)
100100
{
101101
expr *x = node->left;
102102

103103
/* initialize child's wsum_hess */
104-
x->wsum_hess_init(x);
104+
wsum_hess_init(x);
105105

106106
/* for setting weight vector to evaluate hessian of child */
107107
node->work->dwork = (double *) calloc(x->size, sizeof(double));
@@ -166,8 +166,9 @@ expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs)
166166
index_expr *idx = (index_expr *) calloc(1, sizeof(index_expr));
167167
expr *node = &idx->base;
168168

169-
init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian,
170-
is_affine, wsum_hess_init, eval_wsum_hess, free_type_data);
169+
init_expr(node, d1, d2, child->n_vars, forward, jacobian_init_impl,
170+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess,
171+
free_type_data);
171172

172173
node->left = child;
173174
expr_retain(child);

src/affine/left_matmul.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ static void free_type_data(expr *node)
8181
lnode->csc_to_csr_work = NULL;
8282
}
8383

84-
static void jacobian_init(expr *node)
84+
static void jacobian_init_impl(expr *node)
8585
{
8686
expr *x = node->left;
8787
left_matmul_expr *lnode = (left_matmul_expr *) node;
8888

8989
/* initialize child's jacobian and precompute sparsity of its CSC */
90-
x->jacobian_init(x);
90+
jacobian_init(x);
9191
lnode->Jchild_CSC = csr_to_csc_fill_sparsity(x->jacobian, node->work->iwork);
9292

9393
/* precompute sparsity of this node's jacobian in CSC and CSR */
@@ -113,11 +113,11 @@ static void eval_jacobian(expr *node)
113113
csc_to_csr_fill_values(J_CSC, node->jacobian, lnode->csc_to_csr_work);
114114
}
115115

116-
static void wsum_hess_init(expr *node)
116+
static void wsum_hess_init_impl(expr *node)
117117
{
118118
/* initialize child's hessian */
119119
expr *x = node->left;
120-
x->wsum_hess_init(x);
120+
wsum_hess_init(x);
121121

122122
/* allocate this node's hessian with the same sparsity as child's */
123123
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
@@ -169,8 +169,8 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
169169
left_matmul_expr *lnode =
170170
(left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
171171
expr *node = &lnode->base;
172-
init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian,
173-
is_affine, wsum_hess_init, eval_wsum_hess, free_type_data);
172+
init_expr(node, d1, d2, u->n_vars, forward, jacobian_init_impl, eval_jacobian,
173+
is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data);
174174
node->left = u;
175175
expr_retain(u);
176176

@@ -214,8 +214,8 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
214214
left_matmul_expr *lnode =
215215
(left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
216216
expr *node = &lnode->base;
217-
init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian,
218-
is_affine, wsum_hess_init, eval_wsum_hess, free_type_data);
217+
init_expr(node, d1, d2, u->n_vars, forward, jacobian_init_impl, eval_jacobian,
218+
is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data);
219219
node->left = u;
220220
expr_retain(u);
221221

0 commit comments

Comments
 (0)