Skip to content

Commit 9d87ea1

Browse files
committed
Merge branch 'main' into adds-more-affine-atoms
# Conflicts:
#	include/atoms/affine.h
#	src/atoms/affine/diag_mat.c
#	src/atoms/affine/upper_tri.c
#	tests/all_tests.c
2 parents a5dfe1b + bcdb0f0 commit 9d87ea1

185 files changed

Lines changed: 6973 additions & 3522 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ set(CMAKE_C_STANDARD 99)
44
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
55

66
set(DIFF_ENGINE_VERSION_MAJOR 0)
7-
set(DIFF_ENGINE_VERSION_MINOR 1)
8-
set(DIFF_ENGINE_VERSION_PATCH 5)
7+
set(DIFF_ENGINE_VERSION_MINOR 3)
8+
set(DIFF_ENGINE_VERSION_PATCH 0)
99
set(DIFF_ENGINE_VERSION "${DIFF_ENGINE_VERSION_MAJOR}.${DIFF_ENGINE_VERSION_MINOR}.${DIFF_ENGINE_VERSION_PATCH}")
1010
add_compile_definitions(DIFF_ENGINE_VERSION="${DIFF_ENGINE_VERSION}")
1111

@@ -103,6 +103,7 @@ if(NOT SKBUILD)
103103
add_executable(all_tests
104104
tests/all_tests.c
105105
tests/test_helpers.c
106+
tests/numerical_diff.c
106107
)
107108
target_link_libraries(all_tests dnlp_diff)
108109

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,16 @@
2222
#include "subexpr.h"
2323
#include "utils/CSR_Matrix.h"
2424

25-
expr *new_linear(expr *u, const CSR_Matrix *A, const double *b);
26-
2725
expr *new_add(expr *left, expr *right);
2826
expr *new_neg(expr *child);
2927

3028
expr *new_sum(expr *child, int axis);
3129
expr *new_hstack(expr **args, int n_args, int n_vars);
30+
expr *new_vstack(expr **args, int n_args, int n_vars);
3231
expr *new_promote(expr *child, int d1, int d2);
3332
expr *new_trace(expr *child);
3433

35-
expr *new_constant(int d1, int d2, int n_vars, const double *values);
34+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);
3635
expr *new_variable(int d1, int d2, int var_id, int n_vars);
3736

3837
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
@@ -43,4 +42,27 @@ expr *new_diag_mat(expr *child);
4342
expr *new_upper_tri(expr *child);
4443
expr *new_transpose(expr *child);
4544

45+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
46+
* sparse matrix. param_node is NULL for fixed constants. */
47+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
48+
49+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
50+
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
51+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
52+
const double *data);
53+
54+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
55+
* matrix. */
56+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
57+
58+
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
59+
const double *data);
60+
61+
/* Scalar multiplication: a * f(x) where a comes from param_node */
62+
expr *new_scalar_mult(expr *param_node, expr *child);
63+
64+
/* Vector elementwise multiplication: a . f(x) where a comes from
65+
* param_node */
66+
expr *new_vector_mult(expr *param_node, expr *child);
67+
4668
#endif /* AFFINE_H */

include/atoms/bivariate_full_dom.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_FULL_DOM_H
2+
#define BIVARIATE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_elementwise_mult(expr *left, expr *right);
7+
8+
/* Matrix multiplication: Z = X @ Y */
9+
expr *new_matmul(expr *x, expr *y);
10+
11+
#endif /* BIVARIATE_FULL_DOM_H */
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_RESTRICTED_DOM_H
2+
#define BIVARIATE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_quad_over_lin(expr *left, expr *right);
7+
expr *new_rel_entr_vector_args(expr *left, expr *right);
8+
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
9+
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
10+
11+
#endif /* BIVARIATE_RESTRICTED_DOM_H */
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef ELEMENTWISE_FULL_DOM_H
2+
#define ELEMENTWISE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Helper function to initialize an elementwise expr
7+
* (can be used with derived types) */
8+
void init_elementwise(expr *node, expr *child);
9+
10+
expr *new_exp(expr *child);
11+
expr *new_sin(expr *child);
12+
expr *new_cos(expr *child);
13+
expr *new_sinh(expr *child);
14+
expr *new_tanh(expr *child);
15+
expr *new_asinh(expr *child);
16+
expr *new_logistic(expr *child);
17+
expr *new_power(expr *child, double p);
18+
expr *new_xexp(expr *child);
19+
expr *new_normal_cdf(expr *child);
20+
21+
/* the jacobian and wsum_hess for elementwise full domain
22+
atoms are always initialized in the same way and
23+
implement the chain rule in the same way */
24+
void jacobian_init_elementwise(expr *node);
25+
void eval_jacobian_elementwise(expr *node);
26+
void wsum_hess_init_elementwise(expr *node);
27+
void eval_wsum_hess_elementwise(expr *node, const double *w);
28+
expr *new_elementwise(expr *child);
29+
30+
/* no elementwise atoms are affine according to our
31+
convention, so we can have a common implementation */
32+
bool is_affine_elementwise(const expr *node);
33+
34+
#endif /* ELEMENTWISE_FULL_DOM_H */
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef ELEMENTWISE_RESTRICTED_DOM_H
2+
#define ELEMENTWISE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Shared init functions for restricted domain atoms
7+
* (variable-child only, no linear operator support) */
8+
void jacobian_init_restricted(expr *node);
9+
void wsum_hess_init_restricted(expr *node);
10+
bool is_affine_restricted(const expr *node);
11+
expr *new_restricted(expr *child);
12+
13+
expr *new_log(expr *child);
14+
expr *new_entr(expr *child);
15+
expr *new_atanh(expr *child);
16+
expr *new_tan(expr *child);
17+
18+
#endif /* ELEMENTWISE_RESTRICTED_DOM_H */
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#ifndef OTHER_H
19-
#define OTHER_H
18+
#ifndef NON_ELEMENTWISE_FULL_DOM_H
19+
#define NON_ELEMENTWISE_FULL_DOM_H
2020

2121
#include "expr.h"
2222
#include "subexpr.h"
@@ -33,4 +33,4 @@ expr *new_prod_axis_zero(expr *child);
3333
/* product of entries along axis=1 (rowwise products) */
3434
expr *new_prod_axis_one(expr *child);
3535

36-
#endif /* OTHER_H */
36+
#endif /* NON_ELEMENTWISE_FULL_DOM_H */

include/bivariate.h

Lines changed: 0 additions & 51 deletions
This file was deleted.

include/elementwise_univariate.h

Lines changed: 0 additions & 54 deletions
This file was deleted.

include/expr.h

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,24 @@ typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double
3939
typedef bool (*is_affine_fn)(const struct expr *node);
4040
typedef void (*free_type_data_fn)(struct expr *node);
4141

42-
/* Base expression node structure - contains only common fields */
42+
/* Workspace for derivative computation */
43+
typedef struct
44+
{
45+
double *dwork;
46+
int *iwork;
47+
CSC_Matrix *jacobian_csc;
48+
int *csc_work; /* for CSR-CSC conversion */
49+
50+
/* jacobian_csc_filled is only used for affine functions to avoid redundant
51+
conversions. Could become relevant for non-affine functions if we start
52+
supporting common subexpressions on the Python side. */
53+
bool jacobian_csc_filled;
54+
double *local_jac_diag; /* cached f'(g(x)) diagonal */
55+
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
56+
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
57+
} Expr_Work;
58+
59+
/* Base expression node structure */
4360
typedef struct expr
4461
{
4562
// ------------------------------------------------------------------------
@@ -48,8 +65,6 @@ typedef struct expr
4865
int d1, d2, size, n_vars, refcount, var_id;
4966
struct expr *left;
5067
struct expr *right;
51-
double *dwork;
52-
int *iwork;
5368

5469
// ------------------------------------------------------------------------
5570
// oracle related quantities
@@ -58,8 +73,8 @@ typedef struct expr
5873
CSR_Matrix *jacobian;
5974
CSR_Matrix *wsum_hess;
6075
forward_fn forward;
61-
jacobian_init_fn jacobian_init;
62-
wsum_hess_init_fn wsum_hess_init;
76+
jacobian_init_fn jacobian_init_impl;
77+
wsum_hess_init_fn wsum_hess_init_impl;
6378
eval_jacobian_fn eval_jacobian;
6479
wsum_hess_fn eval_wsum_hess;
6580

@@ -70,6 +85,13 @@ typedef struct expr
7085
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
7186
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
7287
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
88+
Expr_Work *work; /* derivative workspace */
89+
/* Set to true on all nodes by problem_update_params() via
90+
expr_set_needs_refresh(). Atoms that cache parameter data
91+
(e.g. left_matmul_dense) check this flag before their forward
92+
pass: if true, they refresh their cached matrices from
93+
param_source->value and clear the flag to false. */
94+
bool needs_parameter_refresh;
7395

7496
// name of node just for debugging - should be removed later
7597
char name[32];
@@ -83,6 +105,18 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
83105

84106
void free_expr(expr *node);
85107

108+
/* Guarded init: skips if already initialized (safe for DAGs
109+
* where a node may be visited through multiple parents). */
110+
void jacobian_init(expr *node);
111+
void wsum_hess_init(expr *node);
112+
113+
/* Initialize CSC form of the Jacobian from the CSR Jacobian.
114+
* Must be called after jacobian_init. */
115+
void jacobian_csc_init(expr *node);
116+
117+
/* Recursively set needs_parameter_refresh on node and all children */
118+
void expr_set_needs_refresh(expr *node);
119+
86120
/* Reference counting helpers */
87121
void expr_retain(expr *node);
88122

0 commit comments

Comments
 (0)