Skip to content

Commit f8f3fdc

Browse files
committed
finished cleaning up elementwise univariate and affine atoms
1 parent 7062b8f commit f8f3fdc

6 files changed

Lines changed: 8 additions & 15 deletions

File tree

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
1. power should be double
2-
2. can we reuse calculations, like in hessian of logistic
31
3. more tests for chain rule elementwise univariate hessian
42
4. in the refactor, add consts
53
5. multiply with one constant vector/scalar argument
64
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
7-
7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code?
5+
7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code! This requires new infrastructure, I think.

include/elementwise_univariate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ expr *new_tanh(expr *child);
1818
expr *new_asinh(expr *child);
1919
expr *new_atanh(expr *child);
2020
expr *new_logistic(expr *child);
21-
expr *new_power(expr *child, int p);
21+
expr *new_power(expr *child, double p);
2222
expr *new_xexp(expr *child);
2323

2424
/* the jacobian and wsum_hess for elementwise univariate atoms are always

include/subexpr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ typedef struct linear_op_expr
2222
typedef struct power_expr
2323
{
2424
expr base;
25-
int p;
25+
double p;
2626
} power_expr;
2727

2828
/* Quadratic form: y = x'*Q*x */

src/elementwise_univariate/power.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static void local_wsum_hess(expr *node, double *out, const double *w)
4040
}
4141
}
4242

43-
expr *new_power(expr *child, int p)
43+
expr *new_power(expr *child, double p)
4444
{
4545
/* Allocate the type-specific struct */
4646
power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr));

src/other/quad_form.c

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
#include <math.h>
66
#include <stdlib.h>
77

8-
/* Note: Q is not freed here because it's owned by the caller */
9-
108
static void forward(expr *node, const double *u)
119
{
1210
expr *x = node->left;
1311

14-
/* children's forward passes */
12+
/* child's forward pass */
1513
x->forward(x, u);
1614

1715
/* local forward pass */
@@ -45,8 +43,7 @@ static void jacobian_init(expr *node)
4543
else /* x is not a variable */
4644
{
4745
/* compute required allocation and allocate jacobian */
48-
bool *col_nz = (bool *) calloc(
49-
node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/
46+
bool *col_nz = (bool *) calloc(node->n_vars, sizeof(bool));
5047
int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz);
5148
node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1);
5249

@@ -65,9 +62,6 @@ static void jacobian_init(expr *node)
6562

6663
node->jacobian->p[0] = 0;
6764
node->jacobian->p[1] = node->jacobian->nnz;
68-
69-
/* Cast x to linear operator to use its A_csc in eval_jacobian */
70-
node->iwork = (int *) malloc(sizeof(int));
7165
}
7266
}
7367

@@ -109,6 +103,7 @@ expr *new_quad_form(expr *left, CSR_Matrix *Q)
109103
expr *node = &qnode->base;
110104

111105
/* Initialize base fields */
106+
assert(left->d2 == 1);
112107
init_expr(node, left->d1, 1, left->n_vars, forward, jacobian_init, eval_jacobian,
113108
NULL, NULL);
114109

tests/wsum_hess/test_power.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const char *test_wsum_hess_power()
1515
double w[3] = {1.0, 2.0, 3.0};
1616

1717
expr *x = new_variable(3, 1, 0, 3);
18-
expr *power_node = new_power(x, 3);
18+
expr *power_node = new_power(x, 3.0);
1919
power_node->forward(power_node, u_vals);
2020
power_node->wsum_hess_init(power_node);
2121
power_node->eval_wsum_hess(power_node, w);

0 commit comments

Comments
 (0)