Skip to content

Commit 105d2ee

Browse files
committed
started refactoring polymorphism
1 parent f3d17b2 commit 105d2ee

12 files changed

Lines changed: 394 additions & 122 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
3. more tests for chain rule elementwise univariate hessian
44
4. in the refactor, add consts
55
5. multiply with one constant vector/scalar argument
6-
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
6+
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
7+
7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code?

include/affine.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,20 @@
44
#include "expr.h"
55
#include "utils/CSR_Matrix.h"
66

7+
/* Helper function to initialize a linear operator expr (can be used with derived
8+
* types) */
9+
void init_linear_op(expr *node, expr *child, int d1, int d2);
10+
711
expr *new_linear(expr *u, const CSR_Matrix *A);
812

913
expr *new_add(expr *left, expr *right);
14+
15+
/* Helper function to initialize a sum expr (can be used with derived types) */
16+
void init_sum(expr *node, expr *child, int d1);
17+
18+
/* Helper function to initialize an hstack expr (can be used with derived types) */
19+
void init_hstack(expr *node, int d1, int d2, int n_vars);
20+
1021
expr *new_sum(expr *child, int axis);
1122
expr *new_hstack(expr **args, int n_args, int n_vars);
1223

include/elementwise_univariate.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "expr.h"
55

6+
/* Helper function to initialize an elementwise expr (can be used with derived types)
7+
*/
8+
void init_elementwise(expr *node, expr *child);
9+
610
expr *new_exp(expr *child);
711
expr *new_log(expr *child);
812
expr *new_entr(expr *child);

include/expr.h

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,55 +21,82 @@ typedef void (*wsum_hess_fn)(struct expr *node, double *w);
2121
typedef void (*local_jacobian_fn)(struct expr *node, double *out);
2222
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w);
2323
typedef bool (*is_affine_fn)(struct expr *node);
24+
typedef void (*free_type_data_fn)(struct expr *node);
2425

25-
/* TODO: implement proper polymorphism */
26-
27-
/* Expression node structure */
26+
/* Base expression node structure - contains only common fields */
2827
typedef struct expr
2928
{
3029
// ------------------------------------------------------------------------
3130
// general quantities
3231
// ------------------------------------------------------------------------
33-
int d1, d2, size;
34-
int n_vars;
35-
int var_id;
36-
int refcount;
32+
int d1, d2, size, n_vars, refcount, var_id;
3733
struct expr *left;
3834
struct expr *right;
39-
struct expr **args; /* hstack can have multiple arguments */
40-
int n_args;
4135
double *dwork;
4236
int *iwork;
43-
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
44-
int p; /* power of power expression */
45-
int axis; /* axis for sum or similar operations */
46-
CSR_Matrix *Q; /* Q for quad_form */
4737

4838
// ------------------------------------------------------------------------
49-
// forward pass related quantities
39+
// oracle related quantities
5040
// ------------------------------------------------------------------------
5141
double *value;
52-
forward_fn forward;
53-
54-
// ------------------------------------------------------------------------
55-
// jacobian related quantities
56-
// ------------------------------------------------------------------------
5742
CSR_Matrix *jacobian;
5843
CSR_Matrix *wsum_hess;
59-
CSR_Matrix *CSR_work;
44+
forward_fn forward;
6045
jacobian_init_fn jacobian_init;
6146
wsum_hess_init_fn wsum_hess_init;
6247
eval_jacobian_fn eval_jacobian;
6348
wsum_hess_fn eval_wsum_hess;
64-
local_jacobian_fn local_jacobian;
65-
local_wsum_hess_fn local_wsum_hess;
49+
50+
// ------------------------------------------------------------------------
51+
// other things
52+
// ------------------------------------------------------------------------
53+
CSR_Matrix *CSR_work;
6654
is_affine_fn is_affine;
55+
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
56+
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
57+
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
6758

68-
// for every linear operator we store A in CSR and CSC
59+
} expr;
60+
61+
/* Type-specific expression structures that "inherit" from expr */
62+
63+
/* Linear operator: y = A * x */
64+
typedef struct linear_op_expr
65+
{
66+
expr base; /* MUST be first member for casting to work */
6967
CSC_Matrix *A_csc;
7068
CSR_Matrix *A_csr;
69+
} linear_op_expr;
7170

72-
} expr;
71+
/* Power: y = x^p */
72+
typedef struct power_expr
73+
{
74+
expr base; /* MUST be first member for casting to work */
75+
int p;
76+
} power_expr;
77+
78+
/* Quadratic form: y = x'*Q*x */
79+
typedef struct quad_form_expr
80+
{
81+
expr base; /* MUST be first member for casting to work */
82+
CSR_Matrix *Q;
83+
} quad_form_expr;
84+
85+
/* Sum reduction along an axis */
86+
typedef struct sum_expr
87+
{
88+
expr base; /* MUST be first member for casting to work */
89+
int axis;
90+
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
91+
} sum_expr;
92+
93+
/* Horizontal stack (concatenate) */
94+
typedef struct hstack_expr
95+
{
96+
expr base; /* MUST be first member for casting to work */
97+
expr **args;
98+
int n_args;
99+
} hstack_expr;
73100

74101
expr *new_expr(int d1, int d2, int n_vars);
75102
void free_expr(expr *node);

include/other.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "expr.h"
55
#include "utils/CSR_Matrix.h"
66

7+
/* Helper function to initialize a quad_form expr (can be used with derived types) */
8+
void init_quad_form(expr *node, expr *child);
9+
710
expr *new_quad_form(expr *child, CSR_Matrix *Q);
811

912
#endif /* OTHER_H */

src/affine/hstack.c

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,55 @@
11
#include "affine.h"
22
#include <assert.h>
3+
#include <stdlib.h>
34
#include <string.h>
45

56
static void forward(expr *node, const double *u)
67
{
8+
hstack_expr *hnode = (hstack_expr *) node;
79

810
/* children's forward passes */
9-
for (int i = 0; i < node->n_args; i++)
11+
for (int i = 0; i < hnode->n_args; i++)
1012
{
11-
node->args[i]->forward(node->args[i], u);
13+
hnode->args[i]->forward(hnode->args[i], u);
1214
}
1315

1416
/* concatenate values horizontally */
1517
int offset = 0;
16-
for (int i = 0; i < node->n_args; i++)
18+
for (int i = 0; i < hnode->n_args; i++)
1719
{
18-
expr *child = node->args[i];
20+
expr *child = hnode->args[i];
1921
memcpy(node->value + offset, child->value, child->size * sizeof(double));
2022
offset += child->size;
2123
}
2224
}
2325

2426
static void jacobian_init(expr *node)
2527
{
28+
hstack_expr *hnode = (hstack_expr *) node;
29+
2630
/* initialize children's jacobians */
2731
int nnz = 0;
28-
for (int i = 0; i < node->n_args; i++)
32+
for (int i = 0; i < hnode->n_args; i++)
2933
{
30-
node->args[i]->jacobian_init(node->args[i]);
31-
nnz += node->args[i]->jacobian->nnz;
34+
hnode->args[i]->jacobian_init(hnode->args[i]);
35+
nnz += hnode->args[i]->jacobian->nnz;
3236
}
3337

3438
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
3539
}
3640

3741
static void eval_jacobian(expr *node)
3842
{
43+
hstack_expr *hnode = (hstack_expr *) node;
44+
3945
/* evaluate children's jacobians */
4046
int row_offset = 0;
4147
CSR_Matrix *A = node->jacobian;
4248
A->nnz = 0;
4349

44-
for (int i = 0; i < node->n_args; i++)
50+
for (int i = 0; i < hnode->n_args; i++)
4551
{
46-
expr *child = node->args[i];
52+
expr *child = hnode->args[i];
4753
child->eval_jacobian(child);
4854
CSR_Matrix *B = child->jacobian;
4955

@@ -65,16 +71,55 @@ static void eval_jacobian(expr *node)
6571

6672
static bool is_affine(expr *node)
6773
{
68-
for (int i = 0; i < node->n_args; i++)
74+
hstack_expr *hnode = (hstack_expr *) node;
75+
76+
for (int i = 0; i < hnode->n_args; i++)
6977
{
70-
if (!node->args[i]->is_affine(node->args[i]))
78+
if (!hnode->args[i]->is_affine(hnode->args[i]))
7179
{
7280
return false;
7381
}
7482
}
7583
return true;
7684
}
7785

86+
static void free_type_data(expr *node)
87+
{
88+
hstack_expr *hnode = (hstack_expr *) node;
89+
for (int i = 0; i < hnode->n_args; i++)
90+
{
91+
free_expr(hnode->args[i]);
92+
}
93+
}
94+
95+
/* Helper function to initialize an hstack expr */
96+
void init_hstack(expr *node, int d1, int d2, int n_vars)
97+
{
98+
node->d1 = d1;
99+
node->d2 = d2;
100+
node->size = d1 * d2;
101+
node->n_vars = n_vars;
102+
node->var_id = -1;
103+
node->refcount = 1;
104+
node->left = NULL;
105+
node->right = NULL;
106+
node->dwork = NULL;
107+
node->iwork = NULL;
108+
node->value = (double *) calloc(node->size, sizeof(double));
109+
node->jacobian = NULL;
110+
node->wsum_hess = NULL;
111+
node->CSR_work = NULL;
112+
node->jacobian_init = jacobian_init;
113+
node->wsum_hess_init = NULL;
114+
node->eval_jacobian = eval_jacobian;
115+
node->eval_wsum_hess = NULL;
116+
node->local_jacobian = NULL;
117+
node->local_wsum_hess = NULL;
118+
node->forward = forward;
119+
node->is_affine = is_affine;
120+
node->free_type_data = free_type_data;
121+
}
122+
78123
expr *new_hstack(expr **args, int n_args, int n_vars)
79124
{
80125
/* compute second dimension */
@@ -84,20 +129,30 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
84129
d2 += args[i]->d2;
85130
}
86131

87-
expr *node = new_expr(args[0]->d1, d2, n_vars);
88-
if (!node) return NULL;
89-
node->args = args;
90-
node->n_args = n_args;
132+
/* Allocate the type-specific struct */
133+
hstack_expr *hnode = (hstack_expr *) malloc(sizeof(hstack_expr));
134+
if (!hnode) return NULL;
135+
136+
expr *node = &hnode->base;
137+
138+
/* Initialize base hstack fields */
139+
init_hstack(node, args[0]->d1, d2, n_vars);
140+
141+
/* Check if allocation succeeded */
142+
if (!node->value)
143+
{
144+
free(hnode);
145+
return NULL;
146+
}
147+
148+
/* Set type-specific fields */
149+
hnode->args = args;
150+
hnode->n_args = n_args;
91151

92152
for (int i = 0; i < n_args; i++)
93153
{
94154
expr_retain(args[i]);
95155
}
96156

97-
node->forward = forward;
98-
node->is_affine = is_affine;
99-
node->jacobian_init = jacobian_init;
100-
node->eval_jacobian = eval_jacobian;
101-
102157
return node;
103158
}

src/affine/linear_op.c

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,70 @@ static bool is_affine(expr *node)
1616
return node->left->is_affine(node->left);
1717
}
1818

19-
expr *new_linear(expr *u, const CSR_Matrix *A)
19+
static void free_type_data(expr *node)
2020
{
21-
expr *node = new_expr(A->m, 1, u->n_vars);
22-
if (!node) return NULL;
21+
linear_op_expr *lin_node = (linear_op_expr *) node;
22+
free_csr_matrix(lin_node->A_csr);
23+
free_csc_matrix(lin_node->A_csc);
24+
}
2325

24-
node->left = u;
25-
expr_retain(u);
26+
/* Helper function to initialize a linear operator expr */
27+
void init_linear_op(expr *node, expr *child, int d1, int d2)
28+
{
29+
node->d1 = d1;
30+
node->d2 = d2;
31+
node->size = d1 * d2;
32+
node->n_vars = child->n_vars;
33+
node->var_id = -1;
34+
node->refcount = 1;
35+
node->left = child;
36+
node->right = NULL;
37+
node->dwork = NULL;
38+
node->iwork = NULL;
39+
node->value = (double *) calloc(node->size, sizeof(double));
40+
node->jacobian = NULL;
41+
node->wsum_hess = NULL;
42+
node->CSR_work = NULL;
43+
node->jacobian_init = NULL;
44+
node->wsum_hess_init = NULL;
45+
node->eval_jacobian = NULL;
46+
node->eval_wsum_hess = NULL;
47+
node->local_jacobian = NULL;
48+
node->local_wsum_hess = NULL;
2649
node->forward = forward;
2750
node->is_affine = is_affine;
51+
node->free_type_data = free_type_data;
52+
53+
expr_retain(child);
54+
}
55+
56+
expr *new_linear(expr *u, const CSR_Matrix *A)
57+
{
58+
/* Allocate the type-specific struct */
59+
linear_op_expr *lin_node = (linear_op_expr *) malloc(sizeof(linear_op_expr));
60+
if (!lin_node) return NULL;
61+
62+
expr *node = &lin_node->base;
63+
64+
/* Initialize base linear operator fields */
65+
init_linear_op(node, u, A->m, 1);
66+
67+
/* Check if allocation succeeded */
68+
if (!node->value)
69+
{
70+
free(lin_node);
71+
return NULL;
72+
}
2873

2974
/* allocate jacobian and copy A into it */
3075
// TODO: this should eventually be removed
3176
node->jacobian = new_csr_matrix(A->m, A->n, A->nnz);
3277
copy_csr_matrix(A, node->jacobian);
3378

34-
node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
35-
copy_csr_matrix(A, node->A_csr);
36-
node->A_csc = csr_to_csc(A);
79+
/* Initialize type-specific fields */
80+
lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
81+
copy_csr_matrix(A, lin_node->A_csr);
82+
lin_node->A_csc = csr_to_csc(A);
3783

3884
return node;
3985
}

0 commit comments

Comments
 (0)