|
7 | 7 | #include <stddef.h> |
8 | 8 |
|
9 | 9 | #define JAC_IDXS_NOT_SET -1 |
10 | | - |
11 | | -/* Forward declarations */ |
12 | | -struct expr; |
13 | | -struct int_double_pair; |
| 10 | +#define NOT_A_VARIABLE -1 |
14 | 11 |
|
15 | 12 | /* Function pointer types */ |
| 13 | +struct expr; |
16 | 14 | typedef void (*forward_fn)(struct expr *node, const double *u); |
17 | 15 | typedef void (*jacobian_init_fn)(struct expr *node); |
18 | 16 | typedef void (*wsum_hess_init_fn)(struct expr *node); |
19 | 17 | typedef void (*eval_jacobian_fn)(struct expr *node); |
20 | | -typedef void (*wsum_hess_fn)(struct expr *node, double *w); |
| 18 | +typedef void (*wsum_hess_fn)(struct expr *node, const double *w); |
21 | 19 | typedef void (*local_jacobian_fn)(struct expr *node, double *out); |
22 | | -typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w); |
23 | | -typedef bool (*is_affine_fn)(struct expr *node); |
24 | | - |
25 | | -/* TODO: implement proper polymorphism */ |
| 20 | +typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w); |
| 21 | +typedef bool (*is_affine_fn)(const struct expr *node); |
| 22 | +typedef void (*free_type_data_fn)(struct expr *node); |
26 | 23 |
|
27 | | -/* Expression node structure */ |
| 24 | +/* Base expression node structure - contains only common fields */ |
28 | 25 | typedef struct expr |
29 | 26 | { |
30 | 27 | // ------------------------------------------------------------------------ |
31 | 28 | // general quantities |
32 | 29 | // ------------------------------------------------------------------------ |
33 | | - int d1, d2, size; |
34 | | - int n_vars; |
35 | | - int var_id; |
36 | | - int refcount; |
| 30 | + int d1, d2, size, n_vars, refcount, var_id; |
37 | 31 | struct expr *left; |
38 | 32 | struct expr *right; |
39 | | - struct expr **args; /* hstack can have multiple arguments */ |
40 | | - int n_args; |
41 | 33 | double *dwork; |
42 | 34 | int *iwork; |
43 | | - struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */ |
44 | | - int p; /* power of power expression */ |
45 | | - int axis; /* axis for sum or similar operations */ |
46 | | - CSR_Matrix *Q; /* Q for quad_form */ |
47 | 35 |
|
48 | 36 | // ------------------------------------------------------------------------ |
49 | | - // forward pass related quantities |
| 37 | + // oracle related quantities |
50 | 38 | // ------------------------------------------------------------------------ |
51 | 39 | double *value; |
52 | | - forward_fn forward; |
53 | | - |
54 | | - // ------------------------------------------------------------------------ |
55 | | - // jacobian related quantities |
56 | | - // ------------------------------------------------------------------------ |
57 | 40 | CSR_Matrix *jacobian; |
58 | 41 | CSR_Matrix *wsum_hess; |
59 | | - CSR_Matrix *CSR_work; |
| 42 | + forward_fn forward; |
60 | 43 | jacobian_init_fn jacobian_init; |
61 | 44 | wsum_hess_init_fn wsum_hess_init; |
62 | 45 | eval_jacobian_fn eval_jacobian; |
63 | 46 | wsum_hess_fn eval_wsum_hess; |
64 | | - local_jacobian_fn local_jacobian; |
65 | | - local_wsum_hess_fn local_wsum_hess; |
66 | | - is_affine_fn is_affine; |
67 | 47 |
|
68 | | - // for every linear operator we store A in CSR and CSC |
69 | | - CSC_Matrix *A_csc; |
70 | | - CSR_Matrix *A_csr; |
| 48 | + // ------------------------------------------------------------------------ |
| 49 | + // other things |
| 50 | + // ------------------------------------------------------------------------ |
| 51 | + is_affine_fn is_affine; |
| 52 | + local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/ |
| 53 | + local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/ |
| 54 | + free_type_data_fn free_type_data; /* Cleanup for type-specific fields */ |
71 | 55 |
|
72 | 56 | } expr; |
73 | 57 |
|
| 58 | +void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, |
| 59 | + jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian, |
| 60 | + is_affine_fn is_affine, free_type_data_fn free_type_data); |
| 61 | + |
74 | 62 | expr *new_expr(int d1, int d2, int n_vars); |
75 | 63 | void free_expr(expr *node); |
76 | 64 |
|
|
0 commit comments