Skip to content

Commit 88d4fed

Browse files
committed
more mental preparation
1 parent 570b50c commit 88d4fed

23 files changed

Lines changed: 89 additions & 11 deletions

src/bivariate_full_dom/multiply.c

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,23 +128,29 @@ static void wsum_hess_init_impl(expr *node)
128128
else
129129
{
130130
/* chain rule: the Hessian is in this case given by
131-
wsum_hess = term1 + term1^T + term2 + term3 where
131+
wsum_hess = C + C^T + term2 + term3 where
132132
133-
* term1 = J_{g2}^T diag(w) J_{g1}
133+
* C = J_{g2}^T diag(w) J_{g1}
134134
* term2 = sum_k w_k g2_k H_{g1_k}
135135
* term3 = sum_k w_k g1_k H_{g2_k}
136136
137-
The two last terms are nonzero only if g1 and g2 are nonlinear.
137+
The last two terms can be nonzero only if g1 (for term2) or g2 (for term3) is nonlinear.
138+
Here, we view multiply as the composition h(x) = f(g1(x), g2(x)) where f
139+
is the elementwise multiplication operator, and g1 and g2 are the left and
140+
right child nodes.
138141
*/
139142

143+
/* prepare sparsity pattern of csc conversion */
144+
jacobian_csc_init(x);
145+
jacobian_csc_init(y);
146+
140147
/* both are linear operators */
141-
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
142-
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;
148+
CSC_Matrix *Jg1 = x->work->jacobian_csc;
149+
CSC_Matrix *Jg2 = y->work->jacobian_csc;
143150

144151
/* Allocate workspace for Hessian computation */
145152
elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
146-
CSR_Matrix *C; /* C = B^T diag(w) A */
147-
C = BTA_alloc(A, B);
153+
CSR_Matrix *C = BTA_alloc(Jg1, Jg2); /* C = Jg2^T diag(w) Jg1 */
148154
node->work->iwork = (int *) malloc(C->m * sizeof(int));
149155

150156
CSR_Matrix *CT = AT_alloc(C, node->work->iwork);
@@ -174,14 +180,42 @@ static void eval_wsum_hess(expr *node, const double *w)
174180
}
175181
else
176182
{
183+
// ----------------------------------------------------------------------
184+
// convert Jacobians of children to CSC format
185+
// (we only need to do this once if the child is affine)
186+
// TODO: what if we have parameters? Should we set jacobian_csc_filled
187+
// to false whenever parameters change value?
188+
// ----------------------------------------------------------------------
189+
if (!x->work->jacobian_csc_filled)
190+
{
191+
csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc,
192+
x->work->csc_work);
193+
194+
if (x->is_affine(x))
195+
{
196+
x->work->jacobian_csc_filled = true;
197+
}
198+
}
199+
200+
if (!y->work->jacobian_csc_filled)
201+
{
202+
csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc,
203+
y->work->csc_work);
204+
205+
if (y->is_affine(y))
206+
{
207+
y->work->jacobian_csc_filled = true;
208+
}
209+
}
210+
177211
/* both are linear operators */
178-
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
179-
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;
212+
CSC_Matrix *Jg1 = x->work->jacobian_csc;
213+
CSC_Matrix *Jg2 = y->work->jacobian_csc;
180214
CSR_Matrix *C = ((elementwise_mult_expr *) node)->CSR_work1;
181215
CSR_Matrix *CT = ((elementwise_mult_expr *) node)->CSR_work2;
182216

183217
/* Compute C = Jg2^T diag(w) Jg1 */
184-
BTDA_fill_values(A, B, w, C);
218+
BTDA_fill_values(Jg1, Jg2, w, C);
185219

186220
/* Compute CT = C^T = Jg1^T diag(w) Jg2 */
187221
AT_fill_values(C, CT, node->work->iwork);

tests/wsum_hess/affine/test_const_scalar_mult.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const char *test_wsum_hess_const_scalar_mult_log_vector(void)
2727
y->forward(y, u_vals);
2828

2929
/* Initialize and evaluate weighted Hessian with w = [1.0, 0.5, 0.25] */
30+
jacobian_init(y);
3031
wsum_hess_init(y);
3132
double w[3] = {1.0, 0.5, 0.25};
3233
y->eval_wsum_hess(y, w);
@@ -72,6 +73,7 @@ const char *test_wsum_hess_const_scalar_mult_log_matrix(void)
7273
y->forward(y, u_vals);
7374

7475
/* Initialize and evaluate weighted Hessian with w = [1.0, 1.0, 1.0, 1.0] */
76+
jacobian_init(y);
7577
wsum_hess_init(y);
7678
double w[4] = {1.0, 1.0, 1.0, 1.0};
7779
y->eval_wsum_hess(y, w);

tests/wsum_hess/affine/test_const_vector_mult.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const char *test_wsum_hess_const_vector_mult_log_vector(void)
2727
y->forward(y, u_vals);
2828

2929
/* Initialize and evaluate weighted Hessian with w = [1.0, 0.5, 0.25] */
30+
jacobian_init(y);
3031
wsum_hess_init(y);
3132
double w[3] = {1.0, 0.5, 0.25};
3233
y->eval_wsum_hess(y, w);
@@ -70,6 +71,7 @@ const char *test_wsum_hess_const_vector_mult_log_matrix(void)
7071
y->forward(y, u_vals);
7172

7273
/* Initialize and evaluate weighted Hessian with w = [1.0, 1.0, 1.0, 1.0] */
74+
jacobian_init(y);
7375
wsum_hess_init(y);
7476
double w[4] = {1.0, 1.0, 1.0, 1.0};
7577
y->eval_wsum_hess(y, w);

tests/wsum_hess/affine/test_hstack.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ const char *test_wsum_hess_hstack_matrix(void)
135135
expr *hstack_node = new_hstack(args, 4, 18);
136136

137137
hstack_node->forward(hstack_node, u_vals);
138+
jacobian_init(hstack_node);
138139
wsum_hess_init(hstack_node);
139140
hstack_node->eval_wsum_hess(hstack_node, w);
140141

tests/wsum_hess/affine/test_transpose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const char *test_wsum_hess_transpose(void)
1818

1919
double u[8] = {1, 3, 2, 4, 5, 7, 6, 8};
2020
XYT->forward(XYT, u);
21+
jacobian_init(XYT);
2122
wsum_hess_init(XYT);
2223
double w[4] = {1, 2, 3, 4};
2324
XYT->eval_wsum_hess(XYT, w);

tests/wsum_hess/bivariate_full_dom/test_matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const char *test_wsum_hess_matmul(void)
4444

4545
/* Forward pass and Hessian initialization */
4646
Z->forward(Z, u_vals);
47+
jacobian_init(Z);
4748
wsum_hess_init(Z);
4849
Z->eval_wsum_hess(Z, w);
4950

@@ -144,6 +145,7 @@ const char *test_wsum_hess_matmul_yx(void)
144145

145146
/* Forward pass and Hessian initialization */
146147
Z->forward(Z, u_vals);
148+
jacobian_init(Z);
147149
wsum_hess_init(Z);
148150
Z->eval_wsum_hess(Z, w);
149151

tests/wsum_hess/bivariate_full_dom/test_multiply.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const char *test_wsum_hess_multiply_1(void)
2222
expr *node = new_elementwise_mult(x, y);
2323

2424
node->forward(node, u_vals);
25+
jacobian_init(node);
2526
wsum_hess_init(node);
2627
node->eval_wsum_hess(node, w);
2728

@@ -79,6 +80,7 @@ const char *test_wsum_hess_multiply_sparse_random(void)
7980
mult_node->forward(mult_node, u_vals);
8081

8182
/* Initialize and evaluate Hessian */
83+
jacobian_init(mult_node);
8284
wsum_hess_init(mult_node);
8385
double w[5] = {0.50646339, 0.44756224, 0.67295241, 0.16424956, 0.03031469};
8486
mult_node->eval_wsum_hess(mult_node, w);
@@ -160,6 +162,7 @@ const char *test_wsum_hess_multiply_linear_ops(void)
160162
mult_node->forward(mult_node, u_vals);
161163

162164
/* Initialize Hessian structure */
165+
jacobian_init(mult_node);
163166
wsum_hess_init(mult_node);
164167

165168
/* Evaluate Hessian with weights */
@@ -207,8 +210,9 @@ const char *test_wsum_hess_multiply_2(void)
207210
expr *y = new_variable(3, 1, 3, 12);
208211
expr *node = new_elementwise_mult(x, y);
209212

210-
node->forward(node, u_vals);
213+
jacobian_init(node);
211214
wsum_hess_init(node);
215+
node->forward(node, u_vals);
212216
node->eval_wsum_hess(node, w);
213217

214218
int expected_p[13] = {0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6};

tests/wsum_hess/bivariate_restricted_dom/test_quad_over_lin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const char *test_wsum_hess_quad_over_lin_xy(void)
1717
expr *node = new_quad_over_lin(x, y);
1818

1919
node->forward(node, u_vals);
20+
jacobian_init(node);
2021
wsum_hess_init(node);
2122
node->eval_wsum_hess(node, &w);
2223

@@ -46,6 +47,7 @@ const char *test_wsum_hess_quad_over_lin_yx(void)
4647
expr *node = new_quad_over_lin(x, y);
4748

4849
node->forward(node, u_vals);
50+
jacobian_init(node);
4951
wsum_hess_init(node);
5052
node->eval_wsum_hess(node, &w);
5153

tests/wsum_hess/bivariate_restricted_dom/test_rel_entr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const char *test_wsum_hess_rel_entr_1(void)
2121
expr *node = new_rel_entr_vector_args(x, y);
2222

2323
node->forward(node, u_vals);
24+
jacobian_init(node);
2425
wsum_hess_init(node);
2526
node->eval_wsum_hess(node, w);
2627

@@ -52,6 +53,7 @@ const char *test_wsum_hess_rel_entr_2(void)
5253
expr *node = new_rel_entr_vector_args(x, y);
5354

5455
node->forward(node, u_vals);
56+
jacobian_init(node);
5557
wsum_hess_init(node);
5658
node->eval_wsum_hess(node, w);
5759

@@ -83,6 +85,7 @@ const char *test_wsum_hess_rel_entr_matrix(void)
8385
expr *node = new_rel_entr_vector_args(x, y);
8486

8587
node->forward(node, u_vals);
88+
jacobian_init(node);
8689
wsum_hess_init(node);
8790
node->eval_wsum_hess(node, w);
8891

tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_scalar_vector.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ const char *test_wsum_hess_rel_entr_scalar_vector(void)
1616
expr *node = new_rel_entr_first_arg_scalar(x, y);
1717

1818
node->forward(node, u_vals);
19+
jacobian_init(node);
1920
wsum_hess_init(node);
2021
node->eval_wsum_hess(node, w);
2122

0 commit comments

Comments
 (0)