Skip to content

Commit 4698872

Browse files
authored
Merge pull request #2 from dance858/refactor-polymorphism
[Ready for review] Refactor polymorphism
2 parents f3d17b2 + 3a23c26 commit 4698872

30 files changed

Lines changed: 397 additions & 311 deletions

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
1. power should be double
2-
2. can we reuse calculations, like in hessian of logistic
31
3. more tests for chain rule elementwise univariate hessian
42
4. in the refactor, add consts
53
5. multiply with one constant vector/scalar argument
6-
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
4+
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.
5+
7. Must be able to compute jacobian and hessian of A @ phi(x), so linear operator needs other code! This requires new infrastructure, I think.

include/affine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
#define AFFINE_H
33

44
#include "expr.h"
5+
#include "subexpr.h"
56
#include "utils/CSR_Matrix.h"
67

78
expr *new_linear(expr *u, const CSR_Matrix *A);
89

910
expr *new_add(expr *left, expr *right);
11+
1012
expr *new_sum(expr *child, int axis);
1113
expr *new_hstack(expr **args, int n_args, int n_vars);
1214

include/elementwise_univariate.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "expr.h"
55

6+
/* Helper function to initialize an elementwise expr (can be used with derived types)
7+
*/
8+
void init_elementwise(expr *node, expr *child);
9+
610
expr *new_exp(expr *child);
711
expr *new_log(expr *child);
812
expr *new_entr(expr *child);
@@ -14,19 +18,19 @@ expr *new_tanh(expr *child);
1418
expr *new_asinh(expr *child);
1519
expr *new_atanh(expr *child);
1620
expr *new_logistic(expr *child);
17-
expr *new_power(expr *child, int p);
21+
expr *new_power(expr *child, double p);
1822
expr *new_xexp(expr *child);
1923

2024
/* the jacobian and wsum_hess for elementwise univariate atoms are always
2125
initialized in the same way and implement the chain rule in the same way */
2226
void jacobian_init_elementwise(expr *node);
2327
void eval_jacobian_elementwise(expr *node);
2428
void wsum_hess_init_elementwise(expr *node);
25-
void eval_wsum_hess_elementwise(expr *node, double *w);
29+
void eval_wsum_hess_elementwise(expr *node, const double *w);
2630
expr *new_elementwise(expr *child);
2731

2832
/* no elementwise atoms are affine according to our convention,
2933
so we can have a common implementation */
30-
bool is_affine_elementwise(expr *node);
34+
bool is_affine_elementwise(const expr *node);
3135

3236
#endif /* ELEMENTWISE_UNIVARIATE_H */

include/expr.h

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,70 +7,58 @@
77
#include <stddef.h>
88

99
#define JAC_IDXS_NOT_SET -1
10-
11-
/* Forward declarations */
12-
struct expr;
13-
struct int_double_pair;
10+
#define NOT_A_VARIABLE -1
1411

1512
/* Function pointer types */
13+
struct expr;
1614
typedef void (*forward_fn)(struct expr *node, const double *u);
1715
typedef void (*jacobian_init_fn)(struct expr *node);
1816
typedef void (*wsum_hess_init_fn)(struct expr *node);
1917
typedef void (*eval_jacobian_fn)(struct expr *node);
20-
typedef void (*wsum_hess_fn)(struct expr *node, double *w);
18+
typedef void (*wsum_hess_fn)(struct expr *node, const double *w);
2119
typedef void (*local_jacobian_fn)(struct expr *node, double *out);
22-
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w);
23-
typedef bool (*is_affine_fn)(struct expr *node);
24-
25-
/* TODO: implement proper polymorphism */
20+
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w);
21+
typedef bool (*is_affine_fn)(const struct expr *node);
22+
typedef void (*free_type_data_fn)(struct expr *node);
2623

27-
/* Expression node structure */
24+
/* Base expression node structure - contains only common fields */
2825
typedef struct expr
2926
{
3027
// ------------------------------------------------------------------------
3128
// general quantities
3229
// ------------------------------------------------------------------------
33-
int d1, d2, size;
34-
int n_vars;
35-
int var_id;
36-
int refcount;
30+
int d1, d2, size, n_vars, refcount, var_id;
3731
struct expr *left;
3832
struct expr *right;
39-
struct expr **args; /* hstack can have multiple arguments */
40-
int n_args;
4133
double *dwork;
4234
int *iwork;
43-
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
44-
int p; /* power of power expression */
45-
int axis; /* axis for sum or similar operations */
46-
CSR_Matrix *Q; /* Q for quad_form */
4735

4836
// ------------------------------------------------------------------------
49-
// forward pass related quantities
37+
// oracle related quantities
5038
// ------------------------------------------------------------------------
5139
double *value;
52-
forward_fn forward;
53-
54-
// ------------------------------------------------------------------------
55-
// jacobian related quantities
56-
// ------------------------------------------------------------------------
5740
CSR_Matrix *jacobian;
5841
CSR_Matrix *wsum_hess;
59-
CSR_Matrix *CSR_work;
42+
forward_fn forward;
6043
jacobian_init_fn jacobian_init;
6144
wsum_hess_init_fn wsum_hess_init;
6245
eval_jacobian_fn eval_jacobian;
6346
wsum_hess_fn eval_wsum_hess;
64-
local_jacobian_fn local_jacobian;
65-
local_wsum_hess_fn local_wsum_hess;
66-
is_affine_fn is_affine;
6747

68-
// for every linear operator we store A in CSR and CSC
69-
CSC_Matrix *A_csc;
70-
CSR_Matrix *A_csr;
48+
// ------------------------------------------------------------------------
49+
// other things
50+
// ------------------------------------------------------------------------
51+
is_affine_fn is_affine;
52+
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
53+
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
54+
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
7155

7256
} expr;
7357

58+
void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
59+
jacobian_init_fn jacobian_init, eval_jacobian_fn eval_jacobian,
60+
is_affine_fn is_affine, free_type_data_fn free_type_data);
61+
7462
expr *new_expr(int d1, int d2, int n_vars);
7563
void free_expr(expr *node);
7664

include/other.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define OTHER_H
33

44
#include "expr.h"
5+
#include "subexpr.h"
56
#include "utils/CSR_Matrix.h"
67

78
expr *new_quad_form(expr *child, CSR_Matrix *Q);

include/subexpr.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef SUBEXPR_H
2+
#define SUBEXPR_H
3+
4+
#include "expr.h"
5+
#include "utils/CSC_Matrix.h"
6+
#include "utils/CSR_Matrix.h"
7+
8+
/* Forward declaration */
9+
struct int_double_pair;
10+
11+
/* Type-specific expression structures that "inherit" from expr */
12+
13+
/* Linear operator: y = A * x */
14+
typedef struct linear_op_expr
15+
{
16+
expr base;
17+
CSC_Matrix *A_csc;
18+
CSR_Matrix *A_csr;
19+
} linear_op_expr;
20+
21+
/* Power: y = x^p */
22+
typedef struct power_expr
23+
{
24+
expr base;
25+
double p;
26+
} power_expr;
27+
28+
/* Quadratic form: y = x'*Q*x */
29+
typedef struct quad_form_expr
30+
{
31+
expr base;
32+
CSR_Matrix *Q;
33+
} quad_form_expr;
34+
35+
/* Sum reduction along an axis */
36+
typedef struct sum_expr
37+
{
38+
expr base;
39+
int axis;
40+
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
41+
} sum_expr;
42+
43+
/* Horizontal stack (concatenate) */
44+
typedef struct hstack_expr
45+
{
46+
expr base;
47+
expr **args;
48+
int n_args;
49+
} hstack_expr;
50+
51+
#endif /* SUBEXPR_H */

include/utils/COO_Matrix.h

Lines changed: 0 additions & 16 deletions
This file was deleted.

include/utils/CSC_Matrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,9 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3939
*/
4040
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
4141

42+
/* C = z^T A where A is in CSC format and C is assumed to have one row.
43+
* C must have column indices pre-computed. Fills in values of C only.
44+
*/
45+
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);
46+
4247
#endif /* CSC_MATRIX_H */

src/affine/add.c

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ static void wsum_hess_init(expr *node)
4545
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
4646
}
4747

48-
static void eval_wsum_hess(expr *node, double *w)
48+
static void eval_wsum_hess(expr *node, const double *w)
4949
{
5050
/* evaluate children's wsum_hess */
5151
node->left->eval_wsum_hess(node->left, w);
@@ -55,20 +55,14 @@ static void eval_wsum_hess(expr *node, double *w)
5555
sum_csr_matrices(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
5656
}
5757

58-
static bool is_affine(expr *node)
58+
static bool is_affine(const expr *node)
5959
{
6060
return node->left->is_affine(node->left) && node->right->is_affine(node->right);
6161
}
6262

6363
expr *new_add(expr *left, expr *right)
6464
{
65-
if (!left || !right) return NULL;
66-
if (left->d1 != right->d1) return NULL;
67-
if (left->d2 != right->d2) return NULL;
68-
6965
expr *node = new_expr(left->d1, left->d2, left->n_vars);
70-
if (!node) return NULL;
71-
7266
node->left = left;
7367
node->right = right;
7468
expr_retain(left);

src/affine/constant.c

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,16 @@ static void forward(expr *node, const double *u)
88
(void) u;
99
}
1010

11-
static bool is_affine(expr *node)
11+
static bool is_affine(const expr *node)
1212
{
1313
(void) node;
14-
return true; /* constant is affine */
14+
return true;
1515
}
1616

1717
expr *new_constant(int d1, int d2, int n_vars, const double *values)
1818
{
1919
expr *node = new_expr(d1, d2, n_vars);
20-
if (!node) return NULL;
21-
22-
/* Copy constant values */
2320
memcpy(node->value, values, d1 * d2 * sizeof(double));
24-
2521
node->forward = forward;
2622
node->is_affine = is_affine;
2723

0 commit comments

Comments
 (0)